import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Splits a grayscale image into non-overlapping patches and linearly embeds each patch."""

    def __init__(self, patch_size=14, emb_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        # A single linear projection maps each flattened patch to an embedding vector.
        self.proj = nn.Linear(patch_size * patch_size, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: Tensor of shape (B, 1, 56, 56)
        returns patch_embeddings of shape (B, 16, emb_dim)
        """
        B, C, H, W = x.shape
        # Tile the image into non-overlapping patch_size x patch_size patches:
        # (B, C, 56, 56) -> (B, C, 4, 4, patch_size, patch_size)
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        # Flatten each patch to a vector: (B, 16, patch_size * patch_size)
        patches = patches.contiguous().view(B, -1, self.patch_size * self.patch_size)
        # Project each flattened patch to emb_dim: (B, 16, emb_dim)
        patch_embeddings = self.proj(patches)
        return patch_embeddings


if __name__ == "__main__":
    feature_extractor = FeatureExtractor()
    dummy_input = torch.randn(8, 1, 56, 56)
    out = feature_extractor(dummy_input)
    print(out.shape)  # expect torch.Size([8, 16, 64])
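
# The unfold-plus-Linear pattern above is mathematically equivalent to a strided
# convolution with kernel size and stride equal to patch_size, which is how
# ViT-style models usually implement patch embedding. Below is a minimal sketch
# verifying that equivalence; the names `fe`, `conv`, `ref`, and `alt` are
# illustrative and not part of the original file. It assumes the Conv2d kernel
# is the Linear weight reshaped to (emb_dim, C, patch_size, patch_size), which
# matches the row-major flattening used in forward().

def _check_conv_equivalence(patch_size: int = 14, emb_dim: int = 64) -> None:
    fe = FeatureExtractor(patch_size, emb_dim)
    conv = nn.Conv2d(1, emb_dim, kernel_size=patch_size, stride=patch_size)
    with torch.no_grad():
        # Reuse the Linear parameters so both paths compute the same function.
        conv.weight.copy_(fe.proj.weight.view(emb_dim, 1, patch_size, patch_size))
        conv.bias.copy_(fe.proj.bias)

    x = torch.randn(8, 1, 56, 56)
    ref = fe(x)                               # (8, 16, 64) via unfold + Linear
    alt = conv(x).flatten(2).transpose(1, 2)  # (8, 16, 64) via strided conv
    print(torch.allclose(ref, alt, atol=1e-5))  # expected: True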