import torch
import torch.nn as nn

class FeatureExtractor(nn.Module):
    """Splits a single-channel image into non-overlapping patches and
    linearly projects each patch to an embedding vector."""

    def __init__(self, patch_size=14, emb_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        # One flattened patch has patch_size * patch_size pixels
        # (single-channel input; C channels would need C * patch_size**2).
        self.proj = nn.Linear(patch_size * patch_size, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: Tensor of shape (B, 1, 56, 56)
        returns patch embeddings of shape (B, 16, emb_dim)
        """
        B, C, H, W = x.shape
        assert C == 1 and H % self.patch_size == 0 and W % self.patch_size == 0

        # Slide over H, then W, cutting the image into non-overlapping
        # patch_size x patch_size tiles: (B, 1, H/p, W/p, p, p).
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        # Flatten each tile into a vector: (B, num_patches, p*p).
        patches = patches.contiguous().view(B, -1, self.patch_size * self.patch_size)
        # Project each flattened patch to the embedding dimension.
        patch_embeddings = self.proj(patches)

        return patch_embeddings



if __name__ == "__main__":
    feature_extractor = FeatureExtractor()
    dummy_input = torch.randn(8, 1, 56, 56)
    out = feature_extractor(dummy_input)
    print(out.shape)  # expected: torch.Size([8, 16, 64])
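
    # Sanity check (a sketch, not in the original file): for non-overlapping
    # patches, the double-unfold in forward() should agree with PyTorch's
    # built-in nn.Unfold, which yields (B, C*p*p, num_patches). Reusing the
    # module's own projection makes the two paths directly comparable.
    unfold = nn.Unfold(kernel_size=14, stride=14)
    ref_patches = unfold(dummy_input).transpose(1, 2)  # (8, 196, 16) -> (8, 16, 196)
    ref_out = feature_extractor.proj(ref_patches)
    print(torch.allclose(out, ref_out))  # should print True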