Srikumar26 committed
Commit 54bf4fb
1 parent: 90dd44f

Create model.py

Files changed (1)
  1. model.py +302 -0
model.py ADDED
@@ -0,0 +1,302 @@
+ # Adapted from: https://github.com/bair-climate-initiative/scale-mae/blob/main/mae/main_finetune.py
+ import torch
+ from timm.models.layers import trunc_normal_
+ from functools import partial
+ import timm.models.vision_transformer
+ import torch.nn as nn
+ from timm.models.vision_transformer import Block, PatchEmbed
+ import os
+ from torchvision.io import read_image
+ import numpy as np
+ import sys
+ import random
+ import pytorch_lightning as pl
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_with_resolution(
+     embed_dim, grid_size, res, cls_token=False, device="cpu"
+ ):
+     """
+     grid_size: int of the grid height and width
+     res: array of size n, representing the resolution of a pixel (say, in meters)
+     return:
+     pos_embed: [n, grid_size*grid_size, embed_dim] or [n, 1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     # res = torch.FloatTensor(res).to(device)
+     res = res.to(device)
+     grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
+     grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
+     grid = torch.meshgrid(
+         grid_w, grid_h, indexing="xy"
+     )  # here h goes first; direction reversed relative to numpy
+     grid = torch.stack(grid, dim=0)  # 2 x h x w
+
+     # grid = grid.reshape([2, 1, grid_size, grid_size])
+     grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
+     _, n, h, w = grid.shape
+     pos_embed = get_2d_sincos_pos_embed_from_grid_torch(
+         embed_dim, grid
+     )  # (n*H*W, D)
+     pos_embed = pos_embed.reshape(n, h * w, embed_dim)
+     if cls_token:
+         pos_embed = torch.cat(
+             [
+                 torch.zeros(
+                     [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device
+                 ),
+                 pos_embed,
+             ],
+             dim=1,
+         )
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid_torch(
+         embed_dim // 2, grid[0]
+     )  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid_torch(
+         embed_dim // 2, grid[1]
+     )  # (H*W, D/2)
+
+     emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = torch.sin(out)  # (M, D/2)
+     emb_cos = torch.cos(out)  # (M, D/2)
+
+     emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float32)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ # --------------------------------------------------------
+ # Interpolate position embeddings for high-resolution
+ # References:
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+ def interpolate_pos_embed(model, checkpoint_model):
+     if "pos_embed" in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model["pos_embed"]
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches**0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print(
+                 "Position interpolate from %dx%d to %dx%d"
+                 % (orig_size, orig_size, new_size, new_size)
+             )
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(
+                 -1, orig_size, orig_size, embedding_size
+             ).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens,
+                 size=(new_size, new_size),
+                 mode="bicubic",
+                 align_corners=False,
+             )
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model["pos_embed"] = new_pos_embed
+
+ class PatchEmbedUnSafe(PatchEmbed):
+     """Image to Patch Embedding"""
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         # Dropped size check in timm
+         # assert H == self.img_size[0] and W == self.img_size[1], \
+         #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+     """Vision Transformer with support for global average pooling"""
+
+     def __init__(
+         self, cls_token_flag=False, global_pool=False, patch_size=16, in_chans=3, embed_dim=1024, **kwargs
+     ):
+         super().__init__(embed_dim=embed_dim, **kwargs)
+         self.cls_token_flag = cls_token_flag
+
+         self.patch_embed = PatchEmbedUnSafe(
+             img_size=224,
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=embed_dim,
+         )
+
+         self.global_pool = global_pool
+         if self.global_pool:
+             norm_layer = kwargs["norm_layer"]
+             self.fc_norm = norm_layer(embed_dim)
+
+             del self.norm  # remove the original norm
+
+         del self.head
+         if not self.cls_token_flag:
+             del self.cls_token
+             del self.pos_embed
+
+     def forward_features(self, x, input_res=None):
+         B, _, h, w = x.shape
+         x = self.patch_embed(x)
+         input_res = input_res.cpu()
+
+         num_patches = int(
+             (h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])
+         )
+         pos_embed = get_2d_sincos_pos_embed_with_resolution(
+             x.shape[-1],
+             int(num_patches**0.5),
+             input_res,
+             cls_token=self.cls_token_flag,
+             device=x.device,
+         )
+
+         if self.cls_token_flag:
+             cls_tokens = self.cls_token.expand(
+                 B, -1, -1
+             )  # stole cls_tokens impl from Phil Wang, thanks
+             x = torch.cat((cls_tokens, x), dim=1)
+         x = x + pos_embed
+         x = self.pos_drop(x)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         # x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
+
+         # fc_norm only exists when global_pool=True; fall back to the original norm otherwise
+         if self.global_pool:
+             outcome = self.fc_norm(x)
+         else:
+             outcome = self.norm(x)
+         return outcome
+
+     def forward(self, x, input_res=None):
+         x = self.forward_features(x, input_res=input_res)
+         return x
+
+
+ def vit_large_patch16(**kwargs):
+     model = VisionTransformer(
+         patch_size=16,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs
+     )
+     return model
+
+
+ def get_ScaleMAE_model(global_pool=True, cls_token=True):
+     model = vit_large_patch16(
+         num_classes=1000,
+         drop_path_rate=0.1,
+         global_pool=global_pool,
+         cls_token_flag=cls_token,
+     )
+
+     # NOTE: the upstream fine-tuning script loads a pretrained checkpoint here and
+     # asserts that only head.* and fc_norm.* weights are missing; that loading step
+     # (and the `msg.missing_keys` check that depended on it) is not part of this file.
+     return model
+
+
+ class ScaleMAE_baseline(pl.LightningModule, PyTorchModelHubMixin):
+     def __init__(self, feat_dim=1024, fc_dim=1024, global_pool=False, cls_token_flag=True):
+         super().__init__()
+         self.model = get_ScaleMAE_model(global_pool=global_pool, cls_token=cls_token_flag)
+
+     def forward(self, x, patch_size, input_res=10.0):
+         # wrap the ground sample distance (in meters) as a 1-element tensor so the
+         # scale-aware positional embedding can broadcast over the batch
+         input_res = torch.tensor([input_res], dtype=torch.float32, device=x.device)
+         x = self.model(x, input_res=input_res)
+         return x
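
For reference (not part of the commit): a minimal sketch of how the resolution-aware positional embedding behaves. The embedding dimension, grid size, and the two ground-sample distances below are illustrative assumptions; the point is that images at different resolutions receive positional codes of the same shape but different spatial frequencies.

import torch
from model import get_2d_sincos_pos_embed_with_resolution  # assuming model.py is importable

# two images: one at 0.3 m/pixel and one at 10 m/pixel (illustrative values)
res = torch.tensor([0.3, 10.0])
pos = get_2d_sincos_pos_embed_with_resolution(1024, 14, res, cls_token=True)
print(pos.shape)  # torch.Size([2, 197, 1024]) -> [n, 1 + 14*14, embed_dim]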
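Likewise, a hedged end-to-end sketch of how the exported class might be used; the batch size, 224x224 input, and 10 m GSD are assumptions for illustration, not values fixed by the commit. Note that forward returns the full token sequence (the mean-pooling line in forward_features is commented out), so any pooling is left to the caller.

import torch
from model import ScaleMAE_baseline  # assuming model.py is importable

model = ScaleMAE_baseline(global_pool=True, cls_token_flag=True)
model.eval()

imgs = torch.randn(2, 3, 224, 224)  # batch of RGB tiles (assumed 224x224 input)
with torch.no_grad():
    feats = model(imgs, patch_size=16, input_res=10.0)  # 10 m ground sample distance
print(feats.shape)  # expected torch.Size([2, 197, 1024]): cls token + 196 patch tokens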