Saghir committed
Commit be806fa
1 Parent(s): be2c585

Uploaded files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/ActivationMap.png filter=lfs diff=lfs merge=lfs -text
PathDino.py ADDED
@@ -0,0 +1,298 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Mostly copy-paste from the timm library.
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ """
+ import math
+ import warnings
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+
+ # from models.dino.utils import trunc_normal_
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x, attn
+
+
+ class Block(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x, return_attention=False):
+         y, attn = self.attn(self.norm1(x))
+         if return_attention:
+             return attn
+         x = x + self.drop_path(y)
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         num_patches = (img_size // patch_size) * (img_size // patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class VisionTransformer(nn.Module):
+     """ Vision Transformer """
+     def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
+         super().__init__()
+         self.num_features = self.embed_dim = embed_dim
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+
+         # Classifier head
+         # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+         trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def interpolate_pos_encoding(self, x, w, h):
+         npatch = x.shape[1] - 1
+         N = self.pos_embed.shape[1] - 1
+         if npatch == N and w == h:
+             return self.pos_embed
+         class_pos_embed = self.pos_embed[:, 0]
+         patch_pos_embed = self.pos_embed[:, 1:]
+         dim = x.shape[-1]
+         w0 = w // self.patch_embed.patch_size
+         h0 = h // self.patch_embed.patch_size
+         # we add a small number to avoid floating point error in the interpolation
+         # see discussion at https://github.com/facebookresearch/dino/issues/8
+         w0, h0 = w0 + 0.1, h0 + 0.1
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+             mode='bicubic',
+         )
+         assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+     def prepare_tokens(self, x):
+         B, nc, w, h = x.shape
+         x = self.patch_embed(x)  # patch linear embedding
+
+         # add the [CLS] token to the embed patch tokens
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # add positional encoding to each token
+         x = x + self.interpolate_pos_encoding(x, w, h)
+
+         return self.pos_drop(x)
+
+     def forward(self, x):
+         x = self.prepare_tokens(x)
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x)
+         return x[:, 0]  # [CLS] token embedding
+
+     def get_last_selfattention(self, x):
+         x = self.prepare_tokens(x)
+         for i, blk in enumerate(self.blocks):
+             if i < len(self.blocks) - 1:
+                 x = blk(x)
+             else:
+                 # return attention of the last block
+                 return blk(x, return_attention=True)
+
+     def get_intermediate_layers(self, x, n=1):
+         x = self.prepare_tokens(x)
+         # we return the output tokens from the `n` last blocks
+         output = []
+         for i, blk in enumerate(self.blocks):
+             x = blk(x)
+             if len(self.blocks) - i <= n:
+                 output.append(self.norm(x))
+         return output
+
+
+ def get_pathDino_model(weights_path="PathDino512.pth", **kwargs):
+     # PathDino backbone: a compact ViT (5 blocks, 384-dim embeddings, 6 heads) for 512x512 inputs
+     model = VisionTransformer(img_size=[512], patch_size=16, embed_dim=384, depth=5, num_heads=6, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     for p in model.parameters():
+         p.requires_grad = False
+     model.eval()
+     # model.to(device)
+     state_dict = torch.load(weights_path, map_location="cpu")
+     # remove `backbone.` prefix induced by multicrop wrapper
+     state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+     # model.load_state_dict(state_dict, strict=False)
+     model.load_state_dict(state_dict)
+
+     data_transforms_PathDino = transforms.Compose([
+         transforms.Resize(512),
+         transforms.CenterCrop(512),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+
+     return model, data_transforms_PathDino
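
A minimal usage sketch of the file added above, assuming PathDino.py and PathDino512.pth sit in the working directory; "patch.png" stands in for the reader's own histology patch:

# Minimal usage sketch: embed one 512x512 patch with PathDino.
# Assumes PathDino.py and PathDino512.pth are in the working directory;
# "patch.png" is a placeholder image path.
import torch
from PIL import Image
from PathDino import get_pathDino_model

model, transform = get_pathDino_model("PathDino512.pth")

img = Image.open("patch.png").convert("RGB")
batch = transform(img).unsqueeze(0)            # shape (1, 3, 512, 512)

with torch.no_grad():
    embedding = model(batch)                   # [CLS] embedding, shape (1, 384)
print(embedding.shape)

The 384-dimensional output is the [CLS] token returned by VisionTransformer.forward, so it can be used directly as a patch-level feature vector for retrieval or clustering.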
PathDino512.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80e4ebeaf18762d9b3fa281d6087f752f2af15ce5d3a007fdb195f8872a9e941
+ size 38273943
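
The checkpoint is tracked by Git LFS, so a clone without LFS support yields only the pointer above. One way to fetch the actual weights programmatically is sketched below; "<user>/<repo>" is a placeholder for this repository's id on the Hub.

# Sketch: download the LFS-backed checkpoint via huggingface_hub.
# "<user>/<repo>" is a placeholder; substitute this repository's actual id.
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(repo_id="<user>/<repo>", filename="PathDino512.pth")
print(weights_path)  # local cache path to the ~38 MB checkpoint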
images/ActivationMap.png ADDED

Git LFS Details

  • SHA256: e8f0d3594d4e5530e8229fb3bcf9ecf1688028c816f9bae2eeaab8dd31c4b6e4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.27 MB
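
ActivationMap.png presumably visualizes self-attention of the kind get_last_selfattention exposes. One way such a map could be computed is sketched below; it is illustrative only, and "patch.png" is again a placeholder image path.

# Sketch: per-head [CLS]-to-patch attention maps from the last PathDino block.
# Assumes the model and transform from get_pathDino_model; "patch.png" is a placeholder.
import torch
from PIL import Image
from PathDino import get_pathDino_model

model, transform = get_pathDino_model("PathDino512.pth")
batch = transform(Image.open("patch.png").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    attn = model.get_last_selfattention(batch)            # (1, 6, 1025, 1025)

side = 512 // 16                                          # 32 patches per side
cls_attn = attn[0, :, 0, 1:].reshape(-1, side, side)      # (6, 32, 32)
heatmaps = torch.nn.functional.interpolate(               # upsample to image resolution
    cls_attn.unsqueeze(0), scale_factor=16, mode="nearest")[0]
print(heatmaps.shape)                                     # torch.Size([6, 512, 512])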
images/FigPathDino_parameters_FLOPs_compare.png ADDED
images/HistRotate.png ADDED
requirements.txt ADDED
File without changes