transformer-mnist-demo / model /feature_extractor.py
nico-x's picture
codebase without model
b54146b
raw
history blame contribute delete
967 Bytes
import torch.nn as nn
import torch
class FeatureExtractor(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    Every ``patch_size x patch_size`` patch (including all of its channels) is
    flattened and mapped to an ``emb_dim``-dimensional vector by one shared
    linear layer. This produces one token per patch, ordered row-major over
    the patch grid.

    Args:
        patch_size: Side length of each square patch, in pixels.
        emb_dim: Size of each output patch embedding.
        in_channels: Number of input channels. The default of 1 (grayscale)
            gives ``proj`` a weight of shape
            ``(emb_dim, patch_size * patch_size)``.
    """

    def __init__(self, patch_size=14, emb_dim=64, in_channels=1):
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.in_channels = in_channels
        self.proj = nn.Linear(in_channels * patch_size * patch_size, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the non-overlapping patches of ``x``.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)`` whose ``H`` and ``W``
                are multiples of ``patch_size``, e.g. ``(B, 1, 56, 56)``.

        Returns:
            Tensor of shape
            ``(B, (H // patch_size) * (W // patch_size), emb_dim)``. For a
            56x56 input and the default ``patch_size=14`` this is
            ``(B, 16, emb_dim)``.

        Raises:
            ValueError: If ``x`` is not 4-D, if its channel count differs from
                ``in_channels``, or if ``H``/``W`` is not a multiple of
                ``patch_size``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        B, C, H, W = x.shape
        p = self.patch_size
        if C != self.in_channels:
            raise ValueError(
                f"expected {self.in_channels} input channel(s), got {C}"
            )
        if H % p or W % p:
            # Otherwise unfold would silently discard the trailing rows/columns.
            raise ValueError(
                f"spatial size {H}x{W} is not divisible by patch_size={p}"
            )
        # (B, C, H, W) -> (B, C, H/p, W/p, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # Move channels next to the pixel dims so each token holds every
        # channel of one spatial patch:
        # (B, H/p, W/p, C, p, p) -> (B, N, C*p*p).
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)
        return self.proj(patches)
if __name__ == "__main__":
    # Smoke check: a batch of eight 1x56x56 images should come out as
    # torch.Size([8, 16, 64]) -- 16 patches, each embedded into 64 dims.
    extractor = FeatureExtractor()
    images = torch.randn(8, 1, 56, 56)
    print(extractor(images).shape)