WishArdently committed
Commit: 10a0e43
1 Parent(s): ecf0260

Upload V_JEPA

Files changed (9):
  1. config.json +0 -1
  2. model.py +37 -10
  3. model.safetensors +1 -1
  4. modules.py +183 -0
  5. patch_embed.py +57 -0
  6. pos_embs.py +99 -0
  7. tensors.py +71 -0
  8. utils.py +23 -0
  9. vision_transformer.py +324 -0
config.json CHANGED
@@ -7,7 +7,6 @@
     "AutoModel": "model.V_JEPA"
   },
   "ckpt_path": "/home/linanxi/V-JEPA/ckpt/vitl16.pth.tar",
-  "device": "cuda:2",
  "model_type": "v-jepa",
  "torch_dtype": "float32",
  "transformers_version": "4.47.0",
model.py CHANGED
@@ -1,10 +1,10 @@
+from unittest.mock import Base
 from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
 from collections import OrderedDict
-from vision_transformer import vit_huge_16, vit_large_16
+from transformers.modeling_outputs import BaseModelOutput
+from .vision_transformer import vit_huge_16, vit_large_16
 import torch
 
-device = 'cuda:2'
-
 class JEPAConfig(PretrainedConfig):
     model_type = "v-jepa"
 
@@ -12,13 +12,12 @@ class JEPAConfig(PretrainedConfig):
         self,
         vit_type: str = 'vit_large_16',
         ckpt_path: str = None,
-        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
         **kwargs
     ):
         super().__init__(**kwargs)
         self.vit_type = vit_type
         self.ckpt_path = ckpt_path
-        self.device = device
+        # print(f'V-JEPA Config: {self.vit_type}, {self.ckpt_path}, {self.device}')
 
 
 class V_JEPA(PreTrainedModel):
@@ -28,9 +27,9 @@ class V_JEPA(PreTrainedModel):
         super().__init__(config)
         self.config = config
         if config.vit_type == 'vit_large_16':
-            self.model = vit_large_16().to(config.device)
+            self.model = vit_large_16()
         elif config.vit_type == 'vit_huge_16':
-            self.model = vit_huge_16().to(config.device)
+            self.model = vit_huge_16()
         else:
             raise ValueError(f"Unsupported vit_type: {config.vit_type}")
 
@@ -39,14 +38,42 @@ class V_JEPA(PreTrainedModel):
 
     def load_checkpoint(self, ckpt_path):
         state_dict = OrderedDict()
-        ckpt = torch.load(ckpt_path, weights_only=False, map_location=self.config.device)['encoder']
+        ckpt = torch.load(ckpt_path, weights_only=False, map_location='cpu')['encoder']
         for k, v in ckpt.items():
             new_key = k.split('.', 1)[-1]
             state_dict[new_key] = v
         self.model.load_state_dict(state_dict, strict=False)
         print("Checkpoint loaded successfully")
 
-    def forward(self, x):
-        return self.model(x)
+    def forward(self, x: torch.Tensor):
+        """Forward pass.
+        Args:
+            x (torch.Tensor): Shape (B, N, C, H, W) or (N, C, H, W)
+        Returns:
+            BaseModelOutput: last_hidden_state has shape (B, N, hidden_size)
+        """
+        # if len(x.shape) == 5 and x.shape[1] == 9:
+        #     x_8T = x[:, :8, :, :, :]
+        #     x_1T = x[:, 8, :, :, :].unsqueeze(1)
+        #     y_8T = self.forward(x_8T).last_hidden_state
+        #     y_1T = self.forward(x_1T).last_hidden_state
+        #     output = torch.cat((y_8T, y_1T), dim=1)
+        #     return BaseModelOutput(last_hidden_state=output)
+
+        if len(x.shape) == 4:
+            x = x.unsqueeze(0)
+        B, N, C, H, W = x.shape
+        x = x.permute(0, 2, 1, 3, 4)  # Shape (B, C, N, H, W)
+        output = self.model(x)
+        output = output.view(B, N, 98, -1)  # Shape (B, N, 98, hidden_size)
+        output = output.mean(dim=2)  # Shape (B, N, hidden_size)
+        # print("output shape: ", output.shape)
+        return BaseModelOutput(last_hidden_state=output)
 
 
+if __name__ == "__main__":
+    config = JEPAConfig()
+    model = V_JEPA(config)
+    x = torch.randn(2, 8, 3, 224, 224)
+    output = model(x)
+    print(output.last_hidden_state.shape)  # torch.Size([2, 8, 1024])
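
For readers skimming load_checkpoint above: each checkpoint key is stored under its name minus the first dotted component before load_state_dict(strict=False). A tiny sketch with toy stand-in values (not the real checkpoint):

from collections import OrderedDict

ckpt = {"module.pos_embed": 0, "module.blocks.0.attn.qkv.weight": 1}  # toy stand-in
state_dict = OrderedDict((k.split('.', 1)[-1], v) for k, v in ckpt.items())
print(list(state_dict))  # ['pos_embed', 'blocks.0.attn.qkv.weight']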
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c81ae01decbe7a322042b7393e75397ff795e2edd9547a8847c025e4c259f9f5
+oid sha256:88ebddcc1c98353e2c8a6c4a55800a9c36e6bd609134c446acf06a2b433e074b
 size 818904440
modules.py ADDED
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        drop=0.
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=False,
+        qk_scale=None,
+        attn_drop=0.,
+        proj_drop=0.,
+        use_sdpa=True
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop_prob = proj_drop
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.use_sdpa = use_sdpa
+
+    def forward(self, x, mask=None):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, num_heads, N, D]
+
+        if self.use_sdpa:
+            with torch.backends.cuda.sdp_kernel():
+                x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob)
+                attn = None
+        else:
+            attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, num_heads, N, N]
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (attn @ v)
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.,
+        attn_drop=0.,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        grid_size=None,
+        grid_depth=None,
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop)
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = MLP(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop)
+
+    def forward(self, x, return_attention=False, mask=None):
+        y, attn = self.attn(self.norm1(x), mask=mask)
+        if return_attention:
+            return attn
+        x = x + y
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads=12,
+        qkv_bias=False,
+        use_sdpa=True
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+        self.q = nn.Linear(dim, dim, bias=qkv_bias)
+        self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+        self.use_sdpa = use_sdpa
+
+    def forward(self, q, x):
+        B, n, C = q.shape
+        q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+        B, N, C = x.shape
+        kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        k, v = kv[0], kv[1]  # (batch_size, num_heads, seq_len, feature_dim_per_head)
+
+        if self.use_sdpa:
+            with torch.backends.cuda.sdp_kernel():
+                q = F.scaled_dot_product_attention(q, k, v)
+        else:
+            xattn = (q @ k.transpose(-2, -1)) * self.scale
+            xattn = xattn.softmax(dim=-1)  # (batch_size, num_heads, query_len, seq_len)
+            q = (xattn @ v)
+
+        q = q.transpose(1, 2).reshape(B, n, C)
+        q = self.proj(q)
+
+        return q
+
+
+class CrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.,
+        qkv_bias=False,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
+
+    def forward(self, q, x):
+        y = self.xattn(q, self.norm1(x))
+        q = q + y
+        q = q + self.mlp(self.norm2(q))
+        return q
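
A quick shape check for the Block defined above, assuming modules.py is importable as a top-level module; the sizes are arbitrary and only illustrate that the residual block preserves the (batch, tokens, dim) shape:

import torch
from modules import Block  # assumes modules.py is on the Python path

blk = Block(dim=1024, num_heads=16, qkv_bias=True)
tokens = torch.randn(2, 784, 1024)   # (batch, num_patches, embed_dim)
out = blk(tokens)
print(out.shape)                     # torch.Size([2, 784, 1024])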
patch_embed.py ADDED
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import torch.nn as nn
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding
+    """
+    def __init__(
+        self,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class PatchEmbed3D(nn.Module):
+    """
+    Video to Patch Embedding
+    """
+
+    def __init__(
+        self,
+        patch_size=16,
+        tubelet_size=2,
+        in_chans=3,
+        embed_dim=768,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.tubelet_size = tubelet_size
+
+        self.proj = nn.Conv3d(
+            in_channels=in_chans,
+            out_channels=embed_dim,
+            kernel_size=(tubelet_size, patch_size, patch_size),
+            stride=(tubelet_size, patch_size, patch_size),
+        )
+
+    def forward(self, x, **kwargs):
+        B, C, T, H, W = x.shape
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
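
A small sanity check for PatchEmbed3D, assuming patch_embed.py is importable directly: a 16-frame 224x224 clip with tubelet_size=2 and patch_size=16 produces (16/2) * (224/16) * (224/16) = 1568 tokens:

import torch
from patch_embed import PatchEmbed3D  # assumes patch_embed.py is on the Python path

embed = PatchEmbed3D(patch_size=16, tubelet_size=2, in_chans=3, embed_dim=1024)
clip = torch.randn(1, 3, 16, 224, 224)   # (B, C, T, H, W)
tokens = embed(clip)
print(tokens.shape)                      # torch.Size([1, 1568, 1024])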
pos_embs.py ADDED
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import numpy as np
+
+
+def get_3d_sincos_pos_embed(
+    embed_dim,
+    grid_size,
+    grid_depth,
+    cls_token=False,
+    uniform_power=False
+):
+    """
+    grid_size: int of the grid height and width
+    grid_depth: int of the grid depth
+    returns:
+        pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token)
+                   or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token)
+    """
+    grid_d = np.arange(grid_depth, dtype=float)
+    grid_h = np.arange(grid_size, dtype=float)
+    grid_w = np.arange(grid_size, dtype=float)
+    grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w)  # order of meshgrid is very important for indexing as [d,h,w]
+
+    if not uniform_power:
+        h_embed_dim = embed_dim // 4
+        w_embed_dim = embed_dim // 4
+        d_embed_dim = embed_dim // 2
+    else:
+        h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2)
+
+    emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h)  # (T*H*W, D1)
+    emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w)  # (T*H*W, D2)
+    emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d)  # (T*H*W, D3)
+    pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1)
+    pos_embed = pos_embed[:, :embed_dim]
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    returns:
+        pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token)
+                   or [1+grid_size*grid_size, embed_dim] (w/ cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=float)
+    grid_w = np.arange(grid_size, dtype=float)
+    grid_w, grid_h = np.meshgrid(grid_w, grid_h)  # order of meshgrid is very important for indexing as [h, w]
+
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w)  # (H*W, D/2)
+    pos_embed = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    embed_dim: output dimension for each position
+    grid_size: int of the grid length
+    returns:
+        pos_embed: [grid_size, embed_dim] (w/o cls_token)
+                   or [1+grid_size, embed_dim] (w/ cls_token)
+    """
+    grid = np.arange(grid_size, dtype=float)
+    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    returns: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
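
The 3D sin-cos table is what VisionTransformer uses for video inputs; with grid_depth=8 and grid_size=14 it matches the 1568 tokens from the patch embed above (a sketch, assuming pos_embs.py is importable directly):

from pos_embs import get_3d_sincos_pos_embed  # assumes pos_embs.py is on the Python path

pe = get_3d_sincos_pos_embed(embed_dim=1024, grid_size=14, grid_depth=8, cls_token=False)
print(pe.shape)  # (1568, 1024)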
tensors.py ADDED
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import math
+
+import torch
+
+from logging import getLogger
+
+logger = getLogger()
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def apply_masks(x, masks):
+    """
+    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
+    :param masks: list of tensors containing indices of patches [0,N) to keep
+    """
+    all_x = []
+    for m in masks:
+        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
+        all_x += [torch.gather(x, dim=1, index=mask_keep)]
+    return torch.cat(all_x, dim=0)
+
+
+def repeat_interleave_batch(x, B, repeat):
+    N = len(x) // B
+    x = torch.cat([
+        torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0)
+        for i in range(N)
+    ], dim=0)
+    return x
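
A short illustration of apply_masks from tensors.py, with toy values chosen only for shape intuition: each mask lists the patch indices to keep per sample.

import torch
from tensors import apply_masks  # assumes tensors.py is on the Python path

x = torch.randn(2, 5, 3)                  # [B, N, D]
masks = [torch.tensor([[0, 1], [2, 3]])]  # one mask with K=2 kept indices per sample
kept = apply_masks(x, masks)
print(kept.shape)                         # torch.Size([2, 2, 3])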
utils.py ADDED
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import torch
+
+
+def apply_masks(x, masks, concat=True):
+    """
+    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
+    :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
+    """
+    all_x = []
+    for m in masks:
+        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
+        all_x += [torch.gather(x, dim=1, index=mask_keep)]
+    if not concat:
+        return all_x
+
+    return torch.cat(all_x, dim=0)
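
utils.apply_masks mirrors tensors.apply_masks but adds a concat flag; with concat=False the per-mask selections come back as a list instead of being concatenated along the batch dimension (sketch under the same importability assumption):

import torch
from utils import apply_masks  # assumes utils.py is on the Python path

x = torch.randn(2, 5, 3)
masks = [torch.tensor([[0, 1], [2, 3]]), torch.tensor([[4, 0], [1, 2]])]
parts = apply_masks(x, masks, concat=False)
print(len(parts), parts[0].shape)  # 2 torch.Size([2, 2, 3])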
vision_transformer.py ADDED
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from .patch_embed import PatchEmbed, PatchEmbed3D
+from .modules import Block
+from .pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
+from .tensors import trunc_normal_
+from .utils import apply_masks
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer """
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        num_frames=1,
+        tubelet_size=2,
+        in_chans=3,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        norm_layer=nn.LayerNorm,
+        init_std=0.02,
+        out_layers=None,
+        uniform_power=False,
+        **kwargs
+    ):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.out_layers = out_layers
+
+        self.input_size = img_size
+        self.patch_size = patch_size
+
+        self.num_frames = num_frames
+        self.tubelet_size = tubelet_size
+        self.is_video = num_frames > 1
+
+        grid_size = self.input_size // self.patch_size
+        grid_depth = self.num_frames // self.tubelet_size
+
+        # Tokenize pixels with convolution
+        if self.is_video:
+            self.patch_embed = PatchEmbed3D(
+                patch_size=patch_size,
+                tubelet_size=tubelet_size,
+                in_chans=in_chans,
+                embed_dim=embed_dim)
+            self.num_patches = (
+                (num_frames // tubelet_size)
+                * (img_size // patch_size)
+                * (img_size // patch_size)
+            )
+        else:
+            self.patch_embed = PatchEmbed(
+                patch_size=patch_size,
+                in_chans=in_chans,
+                embed_dim=embed_dim)
+            self.num_patches = (
+                (img_size // patch_size)
+                * (img_size // patch_size)
+            )
+
+        # Position embedding
+        self.uniform_power = uniform_power
+        self.pos_embed = None
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, self.num_patches, embed_dim),
+            requires_grad=False)
+
+        # Attention Blocks
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                act_layer=nn.GELU,
+                grid_size=grid_size,
+                grid_depth=grid_depth,
+                attn_drop=attn_drop_rate,
+                norm_layer=norm_layer)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        # ------ initialize weights
+        if self.pos_embed is not None:
+            self._init_pos_embed(self.pos_embed.data)  # sincos pos-embed
+        self.init_std = init_std
+        self.apply(self._init_weights)
+        self._rescale_blocks()
+
+    def _init_pos_embed(self, pos_embed):
+        embed_dim = pos_embed.size(-1)
+        grid_size = self.input_size // self.patch_size
+        if self.is_video:
+            grid_depth = self.num_frames // self.tubelet_size
+            sincos = get_3d_sincos_pos_embed(
+                embed_dim,
+                grid_size,
+                grid_depth,
+                cls_token=False,
+                uniform_power=self.uniform_power
+            )
+        else:
+            sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
+        pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=self.init_std)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+        elif isinstance(m, nn.Conv2d):
+            trunc_normal_(m.weight, std=self.init_std)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.Conv3d):
+            trunc_normal_(m.weight, std=self.init_std)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def _rescale_blocks(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    def no_weight_decay(self):
+        return {}
+
+    def forward(self, x, masks=None):
+        """
+        :param x: input image/video
+        :param masks: indices of patch tokens to mask (remove)
+        """
+        if masks is not None and not isinstance(masks, list):
+            masks = [masks]
+
+        # Tokenize input
+        pos_embed = self.pos_embed
+        if pos_embed is not None:
+            pos_embed = self.interpolate_pos_encoding(x, pos_embed)
+        x = self.patch_embed(x)
+        if pos_embed is not None:
+            x += pos_embed
+        B, N, D = x.shape
+
+        # Mask away unwanted tokens (if masks provided)
+        if masks is not None:
+            x = apply_masks(x, masks)
+            masks = torch.cat(masks, dim=0)
+
+        # Fwd prop
+        outs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x, mask=masks)
+            if self.out_layers is not None and i in self.out_layers:
+                outs.append(self.norm(x))
+
+        if self.out_layers is not None:
+            return outs
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
+
+    def interpolate_pos_encoding(self, x, pos_embed):
+
+        _, N, dim = pos_embed.shape
+
+        if self.is_video:
+
+            # If pos_embed already correct size, just return
+            _, _, T, H, W = x.shape
+            if H == self.input_size and W == self.input_size and T == self.num_frames:
+                return pos_embed
+
+            # Convert depth, height, width of input to be measured in patches
+            # instead of pixels/frames
+            T = T // self.tubelet_size
+            H = H // self.patch_size
+            W = W // self.patch_size
+
+            # Compute the initialized shape of the positional embedding measured
+            # in patches
+            N_t = self.num_frames // self.tubelet_size
+            N_h = N_w = self.input_size // self.patch_size
+            assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly'
+
+            # Compute scale factor for spatio-temporal interpolation
+            scale_factor = (T/N_t, H/N_h, W/N_w)
+
+            pos_embed = nn.functional.interpolate(
+                pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3),
+                scale_factor=scale_factor,
+                mode='trilinear')
+            pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim)
+            return pos_embed
+
+        else:
+
+            # If pos_embed already correct size, just return
+            _, _, H, W = x.shape
+            if H == self.input_size and W == self.input_size:
+                return pos_embed
+
+            # Compute scale factor for spatial interpolation
+            npatch = (H // self.patch_size) * (W // self.patch_size)
+            scale_factor = math.sqrt(npatch / N)
+
+            pos_embed = nn.functional.interpolate(
+                pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+                scale_factor=scale_factor,
+                mode='bicubic')
+            pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+            return pos_embed
+
+
+def vit_tiny(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_small(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_base(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_large(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+def vit_huge(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_giant(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_gigantic(patch_size=14, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=64/13,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
+    )
+    return model
+
+
+VIT_EMBED_DIMS = {
+    'vit_tiny': 192,
+    'vit_small': 384,
+    'vit_base': 768,
+    'vit_large': 1024,
+    'vit_huge': 1280,
+    'vit_giant': 1408,
+    'vit_gigantic': 1664,
+}
+
+################
+### Video Encoders ###
+def vit_large_16(patch_size=16, num_frames=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), num_frames=num_frames, **kwargs)
+    return model
+
+def vit_huge_16(patch_size=16, num_frames=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1280, depth=16, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), num_frames=num_frames, **kwargs)
+    return model
+################
+
+if __name__ == '__main__':
+    model = vit_large_16()
+    print('Right.')
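
For context on the hard-coded 98 in model.py: vit_large_16 is built with num_frames=16, but interpolate_pos_encoding rescales the positional table for shorter clips, so an 8-frame clip yields (8/2) * 14 * 14 = 784 tokens, i.e. 98 tokens per input frame. A hedged sketch (these files use relative imports, so they are assumed to be importable as a package, e.g. via the Hub's dynamic module loading):

import torch
from vision_transformer import vit_large_16  # adjust the import to your package layout

model = vit_large_16()                   # patch_size=16, num_frames=16, embed_dim=1024
clip = torch.randn(1, 3, 8, 224, 224)    # (B, C, T, H, W) with 8 frames
tokens = model(clip)
print(tokens.shape)                      # torch.Size([1, 784, 1024]); 784 = 8 * 98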