Bill Psomas committed
Commit aeb9733
1 Parent(s): 86c9545

demo updated

.gitignore ADDED
@@ -0,0 +1,2 @@
+ produce_attmaps.py
+ __pycache__
app.py ADDED
@@ -0,0 +1,99 @@
+ import os
+ import PIL
+ import ast
+ import cv2
+ import json
+ import torch
+ import pickle
+ import torchvision
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from typing import Tuple, Dict
+ import matplotlib.pyplot as plt
+ from timeit import default_timer as timer
+ from torchvision import datasets, transforms
+
+ import vision_transformer as vits
+
+ '''
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ with open('labels/imagenet1k-simple-labels.json') as f:
+     class_names = json.load(f)
+
+ from model import VisionTransformer
+ from capture_weights import vit_weights
+ '''
+
+ arch = "vit_small"
+ mode = "simpool"
+ gamma = None
+ patch_size = 16
+ input_size = 224
+ num_classes = 0
+ checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
+ checkpoint_key = "teacher"
+
+ cm = plt.get_cmap('viridis')
+ attn_map_size = 224
+ width_display = 300
+ height_display = 300
+
+ example_dir = "examples/"
+ example_list = [[example_dir + example] for example in os.listdir(example_dir)]
+ #example_list = "n03017168_54500.JPEG"
+
+ # Load model
+ model = vits.__dict__[arch](
+     mode=mode,
+     gamma=gamma,
+     patch_size=patch_size,
+     num_classes=num_classes,
+ )
+ state_dict = torch.load(checkpoint)
+ state_dict = state_dict[checkpoint_key]
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}
+ msg = model.load_state_dict(state_dict, strict=True)
+
+ model.eval()
+
+ # Define transformations
+ data_transforms = transforms.Compose([
+     transforms.Resize((input_size, input_size), interpolation=3),
+     transforms.ToTensor(),
+     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ])
+
+ def get_attention_map(img):
+     x = data_transforms(img)
+     attn = model.get_simpool_attention(x[None, :, :, :])
+     attn = attn.reshape(1, 1, input_size//patch_size, input_size//patch_size)
+     attn = attn/attn.sum()
+     attn = attn.squeeze()
+     attn = (attn-(attn).min())/((attn).max()-(attn).min())
+     attn = torch.threshold(attn, 0.1, 0)
+
+     attn_img = Image.fromarray(np.uint8(cm(attn.detach().numpy())*255)).convert('RGB')
+     attn_img = attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
+     return attn_img
+
+ attention_interface = gr.Interface(
+     fn=get_attention_map,
+     inputs=[gr.Image(type="pil", label="Input Image")],
+     outputs=gr.Image(type="pil", label="SimPool Attention Map", width=width_display, height=height_display),
+     examples=example_list,
+     title="Explore the Attention Maps of SimPool🔍",
+     description="Upload or use one of the selected images to explore the intricate focus areas of a ViT-S model with SimPool, trained on ImageNet-1k, under supervision."
+ )
+
+ demo = gr.TabbedInterface([attention_interface],
+     ["Visualize Attention Maps"], title="SimPool Attention Map Visualizer 🌌")
+
+ if __name__ == "__main__":
+     demo.launch()
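
For a quick check of the demo outside the Gradio UI, a minimal sketch (assuming the repository's dependencies and the checkpoint above are available locally; the output filename is illustrative) could call get_attention_map directly:

from PIL import Image
import app  # importing app builds the model and restores the checkpoint

img = Image.open("examples/ILSVRC2012_val_00001391_orig.PNG").convert("RGB")
attn_img = app.get_attention_map(img)   # 224x224 RGB viridis heatmap as a PIL Image
attn_img.save("simpool_attention.png")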
checkpoints/vits_dino_simpool_no_gamma_ep100.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dec0fd06409a629b7ca975ec2ed7b124568230893daec3358ffe1b7b5f7127e6
+ size 709505239
examples/ILSVRC2012_val_00001391_orig.PNG ADDED
examples/ILSVRC2012_val_00001398_orig.PNG ADDED
examples/ILSVRC2012_val_00002311_orig.PNG ADDED
examples/ILSVRC2012_val_00003762_orig.PNG ADDED
examples/ILSVRC2012_val_00023778_orig.PNG ADDED
examples/ILSVRC2012_val_00025900_orig.PNG ADDED
examples/ILSVRC2012_val_00037106_orig.PNG ADDED
examples/ILSVRC2012_val_00038638_orig.PNG ADDED
examples/ILSVRC2012_val_00042586_orig.PNG ADDED
examples/ILSVRC2012_val_00049604_orig.PNG ADDED
sp.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ import torch.nn as nn
+
+ class SimPool(nn.Module):
+     def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.norm_patches = nn.LayerNorm(dim, eps=1e-6)
+
+         self.wq = nn.Linear(dim, dim, bias=qkv_bias)
+         self.wk = nn.Linear(dim, dim, bias=qkv_bias)
+
+         if gamma is not None:
+             self.gamma = torch.tensor([gamma], device='cuda')
+             if use_beta:
+                 self.beta = nn.Parameter(torch.tensor([0.0], device='cuda'))
+             self.eps = torch.tensor([1e-6], device='cuda')
+
+         self.gamma = gamma
+         self.use_beta = use_beta
+
+     def prepare_input(self, x):
+         if len(x.shape) == 3: # Transformer
+             # Input tensor dimensions:
+             # x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
+             B, N, d = x.shape
+             gap_cls = x.mean(-2) # (B, N, d) -> (B, d)
+             gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
+             return gap_cls, x
+         if len(x.shape) == 4: # CNN
+             # Input tensor dimensions:
+             # x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
+             B, d, H, W = x.shape
+             gap_cls = x.mean([-2, -1]) # (B, d, H, W) -> (B, d)
+             x = x.reshape(B, d, H*W).permute(0, 2, 1) # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
+             gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
+             return gap_cls, x
+         else:
+             raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")
+
+     def forward(self, x):
+         # Prepare input tensor and perform GAP as initialization
+         gap_cls, x = self.prepare_input(x)
+
+         # Prepare queries (q), keys (k), and values (v)
+         q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)
+
+         # Extract dimensions after normalization
+         Bq, Nq, dq = q.shape
+         Bk, Nk, dk = k.shape
+         Bv, Nv, dv = v.shape
+
+         # Check dimension consistency across batches and channels
+         assert Bq == Bk == Bv
+         assert dq == dk == dv
+
+         # Apply linear transformation for queries and keys then reshape
+         qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1, 3) # (Bq, Nq, dq) -> (B, num_heads, Nq, dq/num_heads)
+         kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1, 3) # (Bk, Nk, dk) -> (B, num_heads, Nk, dk/num_heads)
+
+         vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1, 3) # (Bv, Nv, dv) -> (B, num_heads, Nv, dv/num_heads)
+
+         # Compute attention scores
+         attn = (qq @ kk.transpose(-2, -1)) * self.scale
+         # Apply softmax for normalization
+         attn = attn.softmax(dim=-1)
+
+         # If gamma scaling is used
+         if self.gamma is not None:
+             # Apply gamma scaling on values and compute the weighted sum using attention scores
+             x = torch.pow(attn @ torch.pow((vv - vv.min() + self.eps), self.gamma), 1/self.gamma) # (B, num_heads, Nv, dv/num_heads) -> (B, 1, 1, d)
+             # If use_beta, add a learnable translation
+             if self.use_beta:
+                 x = x + self.beta
+         else:
+             # Compute the weighted sum using attention scores
+             x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)
+
+         return attn
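
A minimal standalone sketch of the two input layouts SimPool accepts (shapes below are illustrative; note that this demo variant of forward returns the attention weights rather than the pooled features):

import torch
from sp import SimPool

pool = SimPool(dim=384, num_heads=1, gamma=None)   # gamma=None matches the demo checkpoint
tokens = torch.randn(2, 196, 384)                  # transformer layout: (B, N, d)
fmap = torch.randn(2, 384, 14, 14)                 # CNN layout: (B, d, H, W)

attn_tokens = pool(tokens)                         # (2, 1, 1, 196)
attn_fmap = pool(fmap)                             # (2, 1, 1, 196)
print(attn_tokens.shape, attn_fmap.shape)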
utils.py ADDED
@@ -0,0 +1,862 @@
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Misc functions.
6
+
7
+ Mostly copy-paste from torchvision references or other public repos like DETR:
8
+ https://github.com/facebookresearch/detr
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import time
14
+ import math
15
+ import random
16
+ import datetime
17
+ import argparse
+ import warnings
+ import subprocess
18
+ from collections import defaultdict, deque
19
+
20
+ import numpy as np
21
+ import torch
22
+ from torch import nn
23
+ import torch.distributed as dist
24
+ from PIL import ImageFilter, ImageOps
25
+
26
+
27
+ class GaussianBlur(object):
28
+ """
29
+ Apply Gaussian Blur to the PIL image.
30
+ """
31
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
32
+ self.prob = p
33
+ self.radius_min = radius_min
34
+ self.radius_max = radius_max
35
+
36
+ def __call__(self, img):
37
+ do_it = random.random() <= self.prob
38
+ if not do_it:
39
+ return img
40
+
41
+ return img.filter(
42
+ ImageFilter.GaussianBlur(
43
+ radius=random.uniform(self.radius_min, self.radius_max)
44
+ )
45
+ )
46
+
47
+
48
+ class Solarization(object):
49
+ """
50
+ Apply Solarization to the PIL image.
51
+ """
52
+ def __init__(self, p):
53
+ self.p = p
54
+
55
+ def __call__(self, img):
56
+ if random.random() < self.p:
57
+ return ImageOps.solarize(img)
58
+ else:
59
+ return img
60
+
61
+
62
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
63
+ if os.path.isfile(pretrained_weights):
64
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
65
+ if checkpoint_key is not None and checkpoint_key in state_dict:
66
+ print(f"Take key {checkpoint_key} in provided checkpoint dict")
67
+ state_dict = state_dict[checkpoint_key]
68
+
69
+ # remove `module.` prefix
70
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
71
+ # remove `backbone.` prefix induced by multicrop wrapper
72
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
73
+
74
+ # Filter out unnecessary keys
75
+ state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}
76
+
77
+ msg = model.load_state_dict(state_dict, strict=True)
78
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
79
+ else:
80
+ print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
81
+ url = None
82
+ if model_name == "vit_small" and patch_size == 16:
83
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
84
+ elif model_name == "vit_small" and patch_size == 8:
85
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
86
+ elif model_name == "vit_base" and patch_size == 16:
87
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
88
+ elif model_name == "vit_base" and patch_size == 8:
89
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
90
+ elif model_name == "xcit_small_12_p16":
91
+ url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
92
+ elif model_name == "xcit_small_12_p8":
93
+ url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
94
+ elif model_name == "xcit_medium_24_p16":
95
+ url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
96
+ elif model_name == "xcit_medium_24_p8":
97
+ url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
98
+ elif model_name == "resnet50":
99
+ url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
100
+ if url is not None:
101
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
102
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
103
+ model.load_state_dict(state_dict, strict=True)
104
+ else:
105
+ print("There is no reference weights available for this model => We use random weights.")
106
+
107
+
108
+ def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
109
+ url = None
110
+ if model_name == "vit_small" and patch_size == 16:
111
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
112
+ elif model_name == "vit_small" and patch_size == 8:
113
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
114
+ elif model_name == "vit_base" and patch_size == 16:
115
+ url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
116
+ elif model_name == "vit_base" and patch_size == 8:
117
+ url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
118
+ elif model_name == "resnet50":
119
+ url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
120
+ if url is not None:
121
+ print("We load the reference pretrained linear weights.")
122
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
123
+ linear_classifier.load_state_dict(state_dict, strict=True)
124
+ else:
125
+ print("We use random linear weights.")
126
+
127
+
128
+ def clip_gradients(model, clip):
129
+ norms = []
130
+ for name, p in model.named_parameters():
131
+ if p.grad is not None:
132
+ param_norm = p.grad.data.norm(2)
133
+ norms.append(param_norm.item())
134
+ clip_coef = clip / (param_norm + 1e-6)
135
+ if clip_coef < 1:
136
+ p.grad.data.mul_(clip_coef)
137
+ return norms
138
+
139
+
140
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
141
+ if epoch >= freeze_last_layer:
142
+ return
143
+ for n, p in model.named_parameters():
144
+ if "last_layer" in n:
145
+ p.grad = None
146
+
147
+
148
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
149
+ """
150
+ Re-start from checkpoint
151
+ """
152
+ if not os.path.isfile(ckp_path):
153
+ return
154
+ print("Found checkpoint at {}".format(ckp_path))
155
+
156
+ # open checkpoint file
157
+ checkpoint = torch.load(ckp_path, map_location="cpu")
158
+
159
+ # key is what to look for in the checkpoint file
160
+ # value is the object to load
161
+ # example: {'state_dict': model}
162
+ for key, value in kwargs.items():
163
+ if key in checkpoint and value is not None:
164
+ try:
165
+ msg = value.load_state_dict(checkpoint[key], strict=False)
166
+ print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
167
+ except TypeError:
168
+ try:
169
+ msg = value.load_state_dict(checkpoint[key])
170
+ print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
171
+ except ValueError:
172
+ print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
173
+ else:
174
+ print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
175
+
176
+ # re-load variables important for the run
177
+ if run_variables is not None:
178
+ for var_name in run_variables:
179
+ if var_name in checkpoint:
180
+ run_variables[var_name] = checkpoint[var_name]
181
+
182
+
183
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
184
+ warmup_schedule = np.array([])
185
+ warmup_iters = warmup_epochs * niter_per_ep
186
+ if warmup_epochs > 0:
187
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
188
+
189
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
190
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
191
+
192
+ schedule = np.concatenate((warmup_schedule, schedule))
193
+ assert len(schedule) == epochs * niter_per_ep
194
+ return schedule
195
+
196
+
197
+ def bool_flag(s):
198
+ """
199
+ Parse boolean arguments from the command line.
200
+ """
201
+ FALSY_STRINGS = {"off", "false", "0"}
202
+ TRUTHY_STRINGS = {"on", "true", "1"}
203
+ if s.lower() in FALSY_STRINGS:
204
+ return False
205
+ elif s.lower() in TRUTHY_STRINGS:
206
+ return True
207
+ else:
208
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
209
+
210
+
211
+ def fix_random_seeds(seed=31):
212
+ """
213
+ Fix random seeds.
214
+ """
215
+ torch.manual_seed(seed)
216
+ torch.cuda.manual_seed_all(seed)
217
+ np.random.seed(seed)
218
+
219
+
220
+ class SmoothedValue(object):
221
+ """Track a series of values and provide access to smoothed values over a
222
+ window or the global series average.
223
+ """
224
+
225
+ def __init__(self, window_size=20, fmt=None):
226
+ if fmt is None:
227
+ fmt = "{median:.6f} ({global_avg:.6f})"
228
+ self.deque = deque(maxlen=window_size)
229
+ self.total = 0.0
230
+ self.count = 0
231
+ self.fmt = fmt
232
+
233
+ def update(self, value, n=1):
234
+ self.deque.append(value)
235
+ self.count += n
236
+ self.total += value * n
237
+
238
+ def synchronize_between_processes(self):
239
+ """
240
+ Warning: does not synchronize the deque!
241
+ """
242
+ if not is_dist_avail_and_initialized():
243
+ return
244
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
245
+ dist.barrier()
246
+ dist.all_reduce(t)
247
+ t = t.tolist()
248
+ self.count = int(t[0])
249
+ self.total = t[1]
250
+
251
+ @property
252
+ def median(self):
253
+ d = torch.tensor(list(self.deque))
254
+ return d.median().item()
255
+
256
+ @property
257
+ def avg(self):
258
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
259
+ return d.mean().item()
260
+
261
+ @property
262
+ def global_avg(self):
263
+ return self.total / self.count
264
+
265
+ @property
266
+ def max(self):
267
+ return max(self.deque)
268
+
269
+ @property
270
+ def value(self):
271
+ return self.deque[-1]
272
+
273
+ def __str__(self):
274
+ return self.fmt.format(
275
+ median=self.median,
276
+ avg=self.avg,
277
+ global_avg=self.global_avg,
278
+ max=self.max,
279
+ value=self.value)
280
+
281
+
282
+ def reduce_dict(input_dict, average=True):
283
+ """
284
+ Args:
285
+ input_dict (dict): all the values will be reduced
286
+ average (bool): whether to do average or sum
287
+ Reduce the values in the dictionary from all processes so that all processes
288
+ have the averaged results. Returns a dict with the same fields as
289
+ input_dict, after reduction.
290
+ """
291
+ world_size = get_world_size()
292
+ if world_size < 2:
293
+ return input_dict
294
+ with torch.no_grad():
295
+ names = []
296
+ values = []
297
+ # sort the keys so that they are consistent across processes
298
+ for k in sorted(input_dict.keys()):
299
+ names.append(k)
300
+ values.append(input_dict[k])
301
+ values = torch.stack(values, dim=0)
302
+ dist.all_reduce(values)
303
+ if average:
304
+ values /= world_size
305
+ reduced_dict = {k: v for k, v in zip(names, values)}
306
+ return reduced_dict
307
+
308
+
309
+ class MetricLogger(object):
310
+ def __init__(self, delimiter="\t"):
311
+ self.meters = defaultdict(SmoothedValue)
312
+ self.delimiter = delimiter
313
+
314
+ def update(self, **kwargs):
315
+ for k, v in kwargs.items():
316
+ if isinstance(v, torch.Tensor):
317
+ v = v.item()
318
+ assert isinstance(v, (float, int))
319
+ self.meters[k].update(v)
320
+
321
+ def __getattr__(self, attr):
322
+ if attr in self.meters:
323
+ return self.meters[attr]
324
+ if attr in self.__dict__:
325
+ return self.__dict__[attr]
326
+ raise AttributeError("'{}' object has no attribute '{}'".format(
327
+ type(self).__name__, attr))
328
+
329
+ def __str__(self):
330
+ loss_str = []
331
+ for name, meter in self.meters.items():
332
+ loss_str.append(
333
+ "{}: {}".format(name, str(meter))
334
+ )
335
+ return self.delimiter.join(loss_str)
336
+
337
+ def synchronize_between_processes(self):
338
+ for meter in self.meters.values():
339
+ meter.synchronize_between_processes()
340
+
341
+ def add_meter(self, name, meter):
342
+ self.meters[name] = meter
343
+
344
+ def log_every(self, iterable, print_freq, header=None):
345
+ i = 0
346
+ if not header:
347
+ header = ''
348
+ start_time = time.time()
349
+ end = time.time()
350
+ iter_time = SmoothedValue(fmt='{avg:.6f}')
351
+ data_time = SmoothedValue(fmt='{avg:.6f}')
352
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
353
+ if torch.cuda.is_available():
354
+ log_msg = self.delimiter.join([
355
+ header,
356
+ '[{0' + space_fmt + '}/{1}]',
357
+ 'eta: {eta}',
358
+ '{meters}',
359
+ 'time: {time}',
360
+ 'data: {data}',
361
+ 'max mem: {memory:.0f}'
362
+ ])
363
+ else:
364
+ log_msg = self.delimiter.join([
365
+ header,
366
+ '[{0' + space_fmt + '}/{1}]',
367
+ 'eta: {eta}',
368
+ '{meters}',
369
+ 'time: {time}',
370
+ 'data: {data}'
371
+ ])
372
+ MB = 1024.0 * 1024.0
373
+ for obj in iterable:
374
+ data_time.update(time.time() - end)
375
+ yield obj
376
+ iter_time.update(time.time() - end)
377
+ if i % print_freq == 0 or i == len(iterable) - 1:
378
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
379
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
380
+ if torch.cuda.is_available():
381
+ print(log_msg.format(
382
+ i, len(iterable), eta=eta_string,
383
+ meters=str(self),
384
+ time=str(iter_time), data=str(data_time),
385
+ memory=torch.cuda.max_memory_allocated() / MB))
386
+ else:
387
+ print(log_msg.format(
388
+ i, len(iterable), eta=eta_string,
389
+ meters=str(self),
390
+ time=str(iter_time), data=str(data_time)))
391
+ i += 1
392
+ end = time.time()
393
+ total_time = time.time() - start_time
394
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
395
+ print('{} Total time: {} ({:.6f} s / it)'.format(
396
+ header, total_time_str, total_time / len(iterable)))
397
+
398
+
399
+ def get_sha():
400
+ cwd = os.path.dirname(os.path.abspath(__file__))
401
+
402
+ def _run(command):
403
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
404
+ sha = 'N/A'
405
+ diff = "clean"
406
+ branch = 'N/A'
407
+ try:
408
+ sha = _run(['git', 'rev-parse', 'HEAD'])
409
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
410
+ diff = _run(['git', 'diff-index', 'HEAD'])
411
+ diff = "has uncommited changes" if diff else "clean"
412
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
413
+ except Exception:
414
+ pass
415
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
416
+ return message
417
+
418
+
419
+ def is_dist_avail_and_initialized():
420
+ if not dist.is_available():
421
+ return False
422
+ if not dist.is_initialized():
423
+ return False
424
+ return True
425
+
426
+
427
+ def get_world_size():
428
+ if not is_dist_avail_and_initialized():
429
+ return 1
430
+ return dist.get_world_size()
431
+
432
+
433
+ def get_rank():
434
+ if not is_dist_avail_and_initialized():
435
+ return 0
436
+ return dist.get_rank()
437
+
438
+
439
+ def is_main_process():
440
+ return get_rank() == 0
441
+
442
+
443
+ def save_on_master(*args, **kwargs):
444
+ if is_main_process():
445
+ torch.save(*args, **kwargs)
446
+
447
+
448
+ def setup_for_distributed(is_master):
449
+ """
450
+ This function disables printing when not in master process
451
+ """
452
+ import builtins as __builtin__
453
+ builtin_print = __builtin__.print
454
+
455
+ def print(*args, **kwargs):
456
+ force = kwargs.pop('force', False)
457
+ if is_master or force:
458
+ builtin_print(*args, **kwargs)
459
+
460
+ __builtin__.print = print
461
+
462
+
463
+ def init_distributed_mode(args):
464
+ # launched with torch.distributed.launch
465
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
466
+ args.rank = int(os.environ["RANK"])
467
+ args.world_size = int(os.environ['WORLD_SIZE'])
468
+ args.gpu = int(os.environ['LOCAL_RANK'])
469
+ # launched with submitit on a slurm cluster
470
+ elif 'SLURM_PROCID' in os.environ:
471
+ args.rank = int(os.environ['SLURM_PROCID'])
472
+ args.gpu = args.rank % torch.cuda.device_count()
473
+ # launched naively with `python main_dino.py`
474
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
475
+ elif torch.cuda.is_available():
476
+ print('Will run the code on one GPU.')
477
+ args.rank, args.gpu, args.world_size = 0, 0, 1
478
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
479
+ os.environ['MASTER_PORT'] = '29500'
480
+ else:
481
+ print('Does not support training without GPU.')
482
+ sys.exit(1)
483
+
484
+ dist.init_process_group(
485
+ backend=args.backend,
486
+ init_method=args.dist_url,
487
+ world_size=args.world_size,
488
+ rank=args.rank,
489
+ )
490
+
491
+ torch.cuda.set_device(args.gpu)
492
+ print('| distributed init (rank {}): {}'.format(
493
+ args.rank, args.dist_url), flush=True)
494
+ dist.barrier()
495
+ setup_for_distributed(args.rank == 0)
496
+
497
+
498
+ def accuracy(output, target, topk=(1,)):
499
+ """Computes the accuracy over the k top predictions for the specified values of k"""
500
+ maxk = max(topk)
501
+ batch_size = target.size(0)
502
+ _, pred = output.topk(maxk, 1, True, True)
503
+ pred = pred.t()
504
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
505
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
506
+
507
+
508
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
509
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
510
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
511
+ def norm_cdf(x):
512
+ # Computes standard normal cumulative distribution function
513
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
514
+
515
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
516
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
517
+ "The distribution of values may be incorrect.",
518
+ stacklevel=2)
519
+
520
+ with torch.no_grad():
521
+ # Values are generated by using a truncated uniform distribution and
522
+ # then using the inverse CDF for the normal distribution.
523
+ # Get upper and lower cdf values
524
+ l = norm_cdf((a - mean) / std)
525
+ u = norm_cdf((b - mean) / std)
526
+
527
+ # Uniformly fill tensor with values from [l, u], then translate to
528
+ # [2l-1, 2u-1].
529
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
530
+
531
+ # Use inverse cdf transform for normal distribution to get truncated
532
+ # standard normal
533
+ tensor.erfinv_()
534
+
535
+ # Transform to proper mean, std
536
+ tensor.mul_(std * math.sqrt(2.))
537
+ tensor.add_(mean)
538
+
539
+ # Clamp to ensure it's in the proper range
540
+ tensor.clamp_(min=a, max=b)
541
+ return tensor
542
+
543
+
544
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
545
+ # type: (Tensor, float, float, float, float) -> Tensor
546
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
547
+
548
+
549
+ class LARS(torch.optim.Optimizer):
550
+ """
551
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
552
+ """
553
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
554
+ weight_decay_filter=None, lars_adaptation_filter=None):
555
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
556
+ eta=eta, weight_decay_filter=weight_decay_filter,
557
+ lars_adaptation_filter=lars_adaptation_filter)
558
+ super().__init__(params, defaults)
559
+
560
+ @torch.no_grad()
561
+ def step(self):
562
+ for g in self.param_groups:
563
+ for p in g['params']:
564
+ dp = p.grad
565
+
566
+ if dp is None:
567
+ continue
568
+
569
+ if p.ndim != 1:
570
+ dp = dp.add(p, alpha=g['weight_decay'])
571
+
572
+ if p.ndim != 1:
573
+ param_norm = torch.norm(p)
574
+ update_norm = torch.norm(dp)
575
+ one = torch.ones_like(param_norm)
576
+ q = torch.where(param_norm > 0.,
577
+ torch.where(update_norm > 0,
578
+ (g['eta'] * param_norm / update_norm), one), one)
579
+ dp = dp.mul(q)
580
+
581
+ param_state = self.state[p]
582
+ if 'mu' not in param_state:
583
+ param_state['mu'] = torch.zeros_like(p)
584
+ mu = param_state['mu']
585
+ mu.mul_(g['momentum']).add_(dp)
586
+
587
+ p.add_(mu, alpha=-g['lr'])
588
+
589
+
590
+ class MultiCropWrapper(nn.Module):
591
+ """
592
+ Perform forward pass separately on each resolution input.
593
+ The inputs corresponding to a single resolution are clubbed and single
594
+ forward is run on the same resolution inputs. Hence we do several
595
+ forward passes = number of different resolutions used. We then
596
+ concatenate all the output features and run the head forward on these
597
+ concatenated features.
598
+ """
599
+ def __init__(self, backbone, head):
600
+ super(MultiCropWrapper, self).__init__()
601
+ # disable layers dedicated to ImageNet labels classification
602
+ backbone.fc, backbone.head = nn.Identity(), nn.Identity()
603
+ self.backbone = backbone
604
+ self.head = head
605
+
606
+ def forward(self, x):
607
+ # convert to list
608
+ if not isinstance(x, list):
609
+ x = [x]
610
+ idx_crops = torch.cumsum(torch.unique_consecutive(
611
+ torch.tensor([inp.shape[-1] for inp in x]),
612
+ return_counts=True,
613
+ )[1], 0)
614
+ start_idx, output = 0, torch.empty(0).to(x[0].device)
615
+ for end_idx in idx_crops:
616
+ _out = self.backbone(torch.cat(x[start_idx: end_idx]))
617
+ # The output is a tuple with XCiT model. See:
618
+ # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
619
+ if isinstance(_out, tuple):
620
+ _out = _out[0]
621
+ # accumulate outputs
622
+ output = torch.cat((output, _out))
623
+ start_idx = end_idx
624
+ # Run the head forward on the concatenated features.
625
+ return self.head(output)
626
+
627
+
628
+ def get_params_groups(model):
629
+ regularized = []
630
+ not_regularized = []
631
+ for name, param in model.named_parameters():
632
+ if not param.requires_grad:
633
+ continue
634
+ # we do not regularize biases nor Norm parameters
635
+ if name.endswith(".bias") or len(param.shape) == 1:
636
+ not_regularized.append(param)
637
+ else:
638
+ regularized.append(param)
639
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
640
+
641
+
642
+ def has_batchnorms(model):
643
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
644
+ for name, module in model.named_modules():
645
+ if isinstance(module, bn_types):
646
+ return True
647
+ return False
648
+
649
+
650
+ class PCA():
651
+ """
652
+ Class to compute and apply PCA.
653
+ """
654
+ def __init__(self, dim=256, whit=0.5):
655
+ self.dim = dim
656
+ self.whit = whit
657
+ self.mean = None
658
+
659
+ def train_pca(self, cov):
660
+ """
661
+ Takes a covariance matrix (np.ndarray) as input.
662
+ """
663
+ d, v = np.linalg.eigh(cov)
664
+ eps = d.max() * 1e-5
665
+ n_0 = (d < eps).sum()
666
+ if n_0 > 0:
667
+ d[d < eps] = eps
668
+
669
+ # total energy
670
+ totenergy = d.sum()
671
+
672
+ # sort eigenvectors with eigenvalues order
673
+ idx = np.argsort(d)[::-1][:self.dim]
674
+ d = d[idx]
675
+ v = v[:, idx]
676
+
677
+ print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
678
+
679
+ # for the whitening
680
+ d = np.diag(1. / d**self.whit)
681
+
682
+ # principal components
683
+ self.dvt = np.dot(d, v.T)
684
+
685
+ def apply(self, x):
686
+ # input is from numpy
687
+ if isinstance(x, np.ndarray):
688
+ if self.mean is not None:
689
+ x -= self.mean
690
+ return np.dot(self.dvt, x.T).T
691
+
692
+ # input is from torch and is on GPU
693
+ if x.is_cuda:
694
+ if self.mean is not None:
695
+ x -= torch.cuda.FloatTensor(self.mean)
696
+ return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
697
+
698
+ # input is from torch, on CPU
699
+ if self.mean is not None:
700
+ x -= torch.FloatTensor(self.mean)
701
+ return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
702
+
703
+
704
+ def compute_ap(ranks, nres):
705
+ """
706
+ Computes average precision for given ranked indexes.
707
+ Arguments
708
+ ---------
709
+ ranks : zero-based ranks of positive images
710
+ nres : number of positive images
711
+ Returns
712
+ -------
713
+ ap : average precision
714
+ """
715
+
716
+ # number of images ranked by the system
717
+ nimgranks = len(ranks)
718
+
719
+ # accumulate trapezoids in PR-plot
720
+ ap = 0
721
+
722
+ recall_step = 1. / nres
723
+
724
+ for j in np.arange(nimgranks):
725
+ rank = ranks[j]
726
+
727
+ if rank == 0:
728
+ precision_0 = 1.
729
+ else:
730
+ precision_0 = float(j) / rank
731
+
732
+ precision_1 = float(j + 1) / (rank + 1)
733
+
734
+ ap += (precision_0 + precision_1) * recall_step / 2.
735
+
736
+ return ap
737
+
738
+
739
+ def compute_map(ranks, gnd, kappas=[]):
740
+ """
741
+ Computes the mAP for a given set of returned results.
742
+ Usage:
743
+ map = compute_map (ranks, gnd)
744
+ computes mean average precision (map) only
745
+ map, aps, pr, prs = compute_map (ranks, gnd, kappas)
746
+ computes mean average precision (map), average precision (aps) for each query
747
+ computes mean precision at kappas (pr), precision at kappas (prs) for each query
748
+ Notes:
749
+ 1) ranks starts from 0, ranks.shape = db_size X #queries
750
+ 2) The junk results (e.g., the query itself) should be declared in the gnd struct array
751
+ 3) If there are no positive images for some query, that query is excluded from the evaluation
752
+ """
753
+
754
+ map = 0.
755
+ nq = len(gnd) # number of queries
756
+ aps = np.zeros(nq)
757
+ pr = np.zeros(len(kappas))
758
+ prs = np.zeros((nq, len(kappas)))
759
+ nempty = 0
760
+
761
+ for i in np.arange(nq):
762
+ qgnd = np.array(gnd[i]['ok'])
763
+
764
+ # no positive images, skip from the average
765
+ if qgnd.shape[0] == 0:
766
+ aps[i] = float('nan')
767
+ prs[i, :] = float('nan')
768
+ nempty += 1
769
+ continue
770
+
771
+ try:
772
+ qgndj = np.array(gnd[i]['junk'])
773
+ except:
774
+ qgndj = np.empty(0)
775
+
776
+ # sorted positions of positive and junk images (0 based)
777
+ pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
778
+ junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
779
+
780
+ k = 0;
781
+ ij = 0;
782
+ if len(junk):
783
+ # decrease positions of positives based on the number of
784
+ # junk images appearing before them
785
+ ip = 0
786
+ while (ip < len(pos)):
787
+ while (ij < len(junk) and pos[ip] > junk[ij]):
788
+ k += 1
789
+ ij += 1
790
+ pos[ip] = pos[ip] - k
791
+ ip += 1
792
+
793
+ # compute ap
794
+ ap = compute_ap(pos, len(qgnd))
795
+ map = map + ap
796
+ aps[i] = ap
797
+
798
+ # compute precision @ k
799
+ pos += 1 # get it to 1-based
800
+ for j in np.arange(len(kappas)):
801
+ kq = min(max(pos), kappas[j]);
802
+ prs[i, j] = (pos <= kq).sum() / kq
803
+ pr = pr + prs[i, :]
804
+
805
+ map = map / (nq - nempty)
806
+ pr = pr / (nq - nempty)
807
+
808
+ return map, aps, pr, prs
809
+
810
+
811
+ def multi_scale(samples, model):
812
+ v = None
813
+ for s in [1, 1/2**(1/2), 1/2]: # we use 3 different scales
814
+ if s == 1:
815
+ inp = samples.clone()
816
+ else:
817
+ inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
818
+ feats = model(inp).clone()
819
+ if v is None:
820
+ v = feats
821
+ else:
822
+ v += feats
823
+ v /= 3
824
+ v /= v.norm()
825
+ return v
826
+
827
+ def subset_of_ImageNet_train_split(dataset_train, subset):
828
+ # Copied from Spyros Gidaris (https://github.com/valeoai/obow/blob/3758504f5e058275725c35ca7faca3731572b911/obow/datasets.py#L244)
829
+ assert isinstance(subset, int)
830
+ assert subset > 0
831
+
832
+ all_indices = []
833
+ for _, img_indices in buildLabelIndex(dataset_train.targets).items():
834
+ assert len(img_indices) >= subset
835
+ all_indices += img_indices[:subset]
836
+
837
+ dataset_train.imgs = [dataset_train.imgs[idx] for idx in all_indices]
838
+ dataset_train.samples = [dataset_train.samples[idx] for idx in all_indices]
839
+ dataset_train.targets = [dataset_train.targets[idx] for idx in all_indices]
840
+ assert len(dataset_train) == (subset * 1000)
841
+
842
+ return dataset_train
843
+
844
+ def buildLabelIndex(labels):
845
+ # Copied from Spyros Gidaris (https://github.com/valeoai/obow/blob/3758504f5e058275725c35ca7faca3731572b911/obow/datasets.py#L38)
846
+ label2inds = {}
847
+ for idx, label in enumerate(labels):
848
+ if label not in label2inds:
849
+ label2inds[label] = []
850
+ label2inds[label].append(idx)
851
+
852
+ return label2inds
853
+
854
+ def float_or_none(value):
855
+ # Convert "None" string to actual None type
856
+ if value == 'None':
857
+ return None
858
+ try:
859
+ # Try converting to float
860
+ return float(value)
861
+ except ValueError:
862
+ raise argparse.ArgumentTypeError(f"Invalid float value: '{value}'")
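
As an illustration of the cosine_scheduler helper above, a small sketch (all values are illustrative) that builds a per-iteration learning-rate schedule with a warmup phase:

from utils import cosine_scheduler

# 100 epochs, 1000 iterations per epoch, 10 warmup epochs.
lr_schedule = cosine_scheduler(
    base_value=5e-4, final_value=1e-6,
    epochs=100, niter_per_ep=1000, warmup_epochs=10,
)
assert len(lr_schedule) == 100 * 1000
print(lr_schedule[0], lr_schedule[10 * 1000], lr_schedule[-1])   # warmup start, peak, final value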
vision_transformer.py ADDED
@@ -0,0 +1,327 @@
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """
5
+ Vision Transformer model
6
+
7
+ Mostly copy-paste from timm library: https://github.com/huggingface/pytorch-image-models
8
+ """
9
+
10
+ import math
11
+ from functools import partial
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from utils import trunc_normal_
17
+
18
+ #TODO: Fix this!
19
+ import sys
20
+ from sp import SimPool
21
+
22
+
23
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
24
+ if drop_prob == 0. or not training:
25
+ return x
26
+ keep_prob = 1 - drop_prob
27
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
28
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
29
+ random_tensor.floor_() # binarize
30
+ output = x.div(keep_prob) * random_tensor
31
+ return output
32
+
33
+
34
+ class DropPath(nn.Module):
35
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
36
+ """
37
+ def __init__(self, drop_prob=None):
38
+ super(DropPath, self).__init__()
39
+ self.drop_prob = drop_prob
40
+
41
+ def forward(self, x):
42
+ return drop_path(x, self.drop_prob, self.training)
43
+
44
+
45
+ class Mlp(nn.Module):
46
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
47
+ super().__init__()
48
+ out_features = out_features or in_features
49
+ hidden_features = hidden_features or in_features
50
+ self.fc1 = nn.Linear(in_features, hidden_features)
51
+ self.act = act_layer()
52
+ self.fc2 = nn.Linear(hidden_features, out_features)
53
+ self.drop = nn.Dropout(drop)
54
+
55
+ def forward(self, x):
56
+ x = self.fc1(x)
57
+ x = self.act(x)
58
+ x = self.drop(x)
59
+ x = self.fc2(x)
60
+ x = self.drop(x)
61
+ return x
62
+
63
+
64
+ class Attention(nn.Module):
65
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
66
+ super().__init__()
67
+ self.num_heads = num_heads
68
+ head_dim = dim // num_heads
69
+ self.scale = qk_scale or head_dim ** -0.5
70
+
71
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
72
+ self.attn_drop = nn.Dropout(attn_drop)
73
+ self.proj = nn.Linear(dim, dim)
74
+ self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+ def forward(self, x):
77
+ B, N, C = x.shape
78
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
79
+ q, k, v = qkv[0], qkv[1], qkv[2]
80
+
81
+ attn = (q @ k.transpose(-2, -1)) * self.scale
82
+ attn = attn.softmax(dim=-1)
83
+ attn = self.attn_drop(attn)
84
+
85
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
86
+ x = self.proj(x)
87
+ x = self.proj_drop(x)
88
+ return x, attn
89
+
90
+
91
+ class Block(nn.Module):
92
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
93
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
94
+ super().__init__()
95
+ self.norm1 = norm_layer(dim)
96
+ self.attn = Attention(
97
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ def forward(self, x, return_attention=False):
104
+ y, attn = self.attn(self.norm1(x))
105
+ if return_attention:
106
+ return attn
107
+ x = x + self.drop_path(y)
108
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
109
+ return x
110
+
111
+
112
+ class PatchEmbed(nn.Module):
113
+ """ Image to Patch Embedding
114
+ """
115
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
116
+ super().__init__()
117
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
118
+ self.img_size = img_size
119
+ self.patch_size = patch_size
120
+ self.num_patches = num_patches
121
+
122
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
123
+
124
+ def forward(self, x):
125
+ B, C, H, W = x.shape
126
+ x = self.proj(x).flatten(2).transpose(1, 2)
127
+ return x
128
+
129
+
130
+ class VisionTransformer(nn.Module):
131
+ """ Vision Transformer """
132
+ def __init__(self, mode, gamma=1.25, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
133
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
134
+ drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
135
+ super().__init__()
136
+ self.mode = mode
137
+
138
+ self.num_features = self.embed_dim = embed_dim
139
+
140
+ self.patch_embed = PatchEmbed(
141
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
142
+ num_patches = self.patch_embed.num_patches
143
+
144
+ if mode == 'official':
145
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147
+ else:
148
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
149
+
150
+ self.pos_drop = nn.Dropout(p=drop_rate)
151
+
152
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
153
+ self.blocks = nn.ModuleList([
154
+ Block(
155
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
156
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
157
+ for i in range(depth)])
158
+ self.norm = norm_layer(embed_dim)
159
+
160
+ if mode == 'simpool':
161
+ self.simpool = SimPool(embed_dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=gamma, use_beta=True)
162
+
163
+ # Classifier head
164
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
165
+
166
+ trunc_normal_(self.pos_embed, std=.02)
167
+ if mode == 'official':
168
+ trunc_normal_(self.cls_token, std=.02)
169
+ self.apply(self._init_weights)
170
+
171
+ def _init_weights(self, m):
172
+ if isinstance(m, nn.Linear):
173
+ trunc_normal_(m.weight, std=.02)
174
+ if isinstance(m, nn.Linear) and m.bias is not None:
175
+ nn.init.constant_(m.bias, 0)
176
+ elif isinstance(m, nn.LayerNorm):
177
+ nn.init.constant_(m.bias, 0)
178
+ nn.init.constant_(m.weight, 1.0)
179
+
180
+ def interpolate_pos_encoding(self, x, w, h):
181
+ npatch = x.shape[1] - 1
182
+ if self.mode == 'official':
183
+ N = self.pos_embed.shape[1] - 1
184
+ else:
185
+ N = self.pos_embed.shape[1]
186
+ if npatch == N and w == h:
187
+ return self.pos_embed
188
+ if self.mode == 'official':
189
+ class_pos_embed = self.pos_embed[:, 0]
190
+ patch_pos_embed = self.pos_embed[:, 1:]
191
+ else:
192
+ patch_pos_embed = self.pos_embed
193
+ dim = x.shape[-1]
194
+ w0 = w // self.patch_embed.patch_size
195
+ h0 = h // self.patch_embed.patch_size
196
+ # we add a small number to avoid floating point error in the interpolation
197
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
198
+ w0, h0 = w0 + 0.1, h0 + 0.1
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
202
+ mode='bicubic',
203
+ )
204
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
205
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
206
+ if self.mode == 'official':
207
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
208
+ else:
209
+ return patch_pos_embed
210
+
211
+ def prepare_tokens(self, x):
212
+ B, nc, w, h = x.shape
213
+ x = self.patch_embed(x) # patch linear embedding
214
+
215
+ if self.mode == 'official':
216
+ # add the [CLS] token to the embed patch tokens
217
+ cls_tokens = self.cls_token.expand(B, -1, -1)
218
+ x = torch.cat((cls_tokens, x), dim=1)
219
+
220
+ # add positional encoding to each token
221
+ x = x + self.interpolate_pos_encoding(x, w, h)
222
+
223
+ return self.pos_drop(x)
224
+
225
+ def forward(self, x):
226
+ x = self.prepare_tokens(x)
227
+ for blk in self.blocks:
228
+ x = blk(x)
229
+
230
+ if self.mode == 'simpool':
231
+ x = self.simpool(x)
232
+ return self.norm(x)
233
+ else:
234
+ x = self.norm(x)
235
+ return x[:, 0]
236
+
237
+ def get_last_selfattention(self, x):
238
+ x = self.prepare_tokens(x)
239
+ for i, blk in enumerate(self.blocks):
240
+ if i < len(self.blocks) - 1:
241
+ x = blk(x)
242
+ else:
243
+ # return attention of the last block
244
+ return blk(x, return_attention=True)
245
+
246
+ def get_block_selfattention(self, x, block_index):
247
+ x = self.prepare_tokens(x)
248
+ for i, blk in enumerate(self.blocks):
249
+ if i == block_index:
250
+ # return attention of the specified block
251
+ return blk(x, return_attention=True)
252
+ x = blk(x)
253
+
254
+ def get_simpool_attention(self, x):
255
+ x = self.prepare_tokens(x)
256
+ for blk in self.blocks:
257
+ x = blk(x)
258
+
259
+ attn = self.simpool(x)
260
+ return attn
261
+ def get_intermediate_layers(self, x, n=1):
262
+ x = self.prepare_tokens(x)
263
+ # we return the output tokens from the `n` last blocks
264
+ output = []
265
+ for i, blk in enumerate(self.blocks):
266
+ x = blk(x)
267
+ if len(self.blocks) - i <= n:
268
+ output.append(self.norm(x))
269
+ return output
270
+
271
+
272
+ def vit_tiny(mode='official', patch_size=16, **kwargs):
273
+ model = VisionTransformer(
274
+ mode=mode, patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
275
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
276
+ return model
277
+
278
+
279
+ def vit_small(mode='official', patch_size=16, **kwargs):
280
+ model = VisionTransformer(
281
+ mode=mode, patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
282
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
283
+ return model
284
+
285
+
286
+ def vit_base(mode='official', patch_size=16, **kwargs):
287
+ model = VisionTransformer(
288
+ mode=mode, patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
289
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
290
+ return model
291
+
292
+
293
+ class DINOHead(nn.Module):
294
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
295
+ super().__init__()
296
+ nlayers = max(nlayers, 1)
297
+ if nlayers == 1:
298
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
299
+ else:
300
+ layers = [nn.Linear(in_dim, hidden_dim)]
301
+ if use_bn:
302
+ layers.append(nn.BatchNorm1d(hidden_dim))
303
+ layers.append(nn.GELU())
304
+ for _ in range(nlayers - 2):
305
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
306
+ if use_bn:
307
+ layers.append(nn.BatchNorm1d(hidden_dim))
308
+ layers.append(nn.GELU())
309
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
310
+ self.mlp = nn.Sequential(*layers)
311
+ self.apply(self._init_weights)
312
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
313
+ self.last_layer.weight_g.data.fill_(1)
314
+ if norm_last_layer:
315
+ self.last_layer.weight_g.requires_grad = False
316
+
317
+ def _init_weights(self, m):
318
+ if isinstance(m, nn.Linear):
319
+ trunc_normal_(m.weight, std=.02)
320
+ if isinstance(m, nn.Linear) and m.bias is not None:
321
+ nn.init.constant_(m.bias, 0)
322
+
323
+ def forward(self, x):
324
+ x = self.mlp(x)
325
+ x = nn.functional.normalize(x, dim=-1, p=2)
326
+ x = self.last_layer(x)
327
+ return x
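
A minimal sketch (random weights, illustrative shapes; no checkpoint loaded here) of building the SimPool variant of ViT-S/16 defined above and probing its pooled attention via get_simpool_attention, as app.py does with the pretrained weights:

import torch
import vision_transformer as vits

# SimPool variant of ViT-S/16; gamma=None as in the demo configuration.
model = vits.vit_small(mode="simpool", gamma=None, patch_size=16, num_classes=0)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    attn = model.get_simpool_attention(x)   # (1, 1, 1, 196) for a 224x224 input with patch size 16
print(attn.shape)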