rameye committed
Commit 561fb21
1 Parent(s): 294eb66

Update pytorch_model.bin

Files changed (1)
  1. pytorch_model.bin +54 -3
pytorch_model.bin CHANGED
@@ -1,4 +1,55 @@
- from transformers import ViTForImageClassification
-
- # Assuming you're loading the model using the ViTForImageClassification class
- model = ViTForImageClassification.from_pretrained('/home/user/.cache/huggingface/hub/models--rameye--1/snapshots/c2c5c38a641e6c048f33c2db25afa765727a04ed', from_tf=True)
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ # Partly revised by YZ @UCL&Moorfields
+ # --------------------------------------------------------
+
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+
+ import timm.models.vision_transformer
+
+
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+     """ Vision Transformer with support for global average pooling
+     """
+     def __init__(self, global_pool=False, **kwargs):
+         super(VisionTransformer, self).__init__(**kwargs)
+
+         self.global_pool = global_pool
+         if self.global_pool:
+             norm_layer = kwargs['norm_layer']
+             embed_dim = kwargs['embed_dim']
+             self.fc_norm = norm_layer(embed_dim)
+
+             del self.norm  # remove the original norm
+
+     def forward_features(self, x):
+         B = x.shape[0]
+         x = self.patch_embed(x)
+
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         if self.global_pool:
+             x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
+             outcome = self.fc_norm(x)
+         else:
+             x = self.norm(x)
+             outcome = x[:, 0]
+
+         return outcome
+
+
+ def vit_large_patch16(**kwargs):
+     model = VisionTransformer(
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
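
For reference, a minimal usage sketch (not part of the commit): it builds the ViT-Large backbone defined in the diff above with global average pooling enabled and loads this repository's weights. The checkpoint filename, the num_classes value, and the assumption that pytorch_model.bin deserializes to a plain state dict (possibly wrapped under a "model" key) are illustrative guesses, not confirmed by the diff.

import torch

# Build the backbone from the vit_large_patch16 factory defined above;
# num_classes=2 is a hypothetical head size chosen only for illustration.
model = vit_large_patch16(num_classes=2, global_pool=True)

# Assumed: the .bin file deserializes to a state dict, possibly nested under "model".
checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
state_dict = checkpoint.get("model", checkpoint) if isinstance(checkpoint, dict) else checkpoint

# strict=False tolerates head/fc_norm keys that may differ from a pre-trained encoder.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)

model.eval()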