Paolo-Fraccaro commited on
Commit
b8e0a76
1 Parent(s): c70cc56

Update Prithvi.py

Browse files
Files changed (1) hide show
  1. Prithvi.py +31 -3
Prithvi.py CHANGED
@@ -15,12 +15,42 @@ import torch
15
  import torch.nn as nn
16
 
17
  from timm.models.vision_transformer import Block
18
- from timm.models.layers import to_2tuple, _assert
19
 
20
  import numpy as np
21
 
22
  from einops import rearrange
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
25
  """
26
  grid_size: 3d tuple of grid size: t, h, w
@@ -85,8 +115,6 @@ class PatchEmbed(nn.Module):
85
 
86
  def forward(self, x):
87
  B, C, T, H, W = x.shape
88
- _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
89
- _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
90
  x = self.proj(x)
91
  if self.flatten:
92
  x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
 
15
  import torch.nn as nn
16
 
17
  from timm.models.vision_transformer import Block
18
+ from timm.models.layers import to_2tuple
19
 
20
  import numpy as np
21
 
22
  from einops import rearrange
23
 
24
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
25
+ """
26
+ embed_dim: output dimension for each position
27
+ pos: a list of positions to be encoded: size (M,)
28
+ out: (M, D)
29
+ """
30
+ assert embed_dim % 2 == 0
31
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
32
+ omega /= embed_dim / 2.
33
+ omega = 1. / 10000**omega # (D/2,)
34
+
35
+ pos = pos.reshape(-1) # (M,)
36
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
37
+
38
+ emb_sin = np.sin(out) # (M, D/2)
39
+ emb_cos = np.cos(out) # (M, D/2)
40
+
41
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
42
+ return emb
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
+ assert embed_dim % 2 == 0
46
+
47
+ # use half of dimensions to encode grid_h
48
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
49
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
50
+
51
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
52
+ return emb
53
+
54
  def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
55
  """
56
  grid_size: 3d tuple of grid size: t, h, w
 
115
 
116
  def forward(self, x):
117
  B, C, T, H, W = x.shape
 
 
118
  x = self.proj(x)
119
  if self.flatten:
120
  x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C