MAE-Demo / model.py
turhancan97's picture
app file created
929f451
raw
history blame
8.76 kB
# --------------------------------------------------------
# References:
# MAE: https://github.com/IcarusWizard/MAE
# --------------------------------------------------------
import torch
import timm
import numpy as np
from einops import repeat, rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block
def random_indexes(size : int):
    '''
    Draw a random permutation of ``range(size)`` together with its inverse.

    The permutation is produced by shuffling ``np.arange(size)`` in place
    with NumPy's global random state, so it is controlled by ``np.random.seed``.

    Returns:
        A tuple ``(forward_indexes, backward_indexes)`` of 1-D integer arrays.
        ``backward_indexes`` is the argsort of ``forward_indexes``, hence
        ``forward_indexes[backward_indexes]`` is ``0, 1, ..., size - 1``.
    '''
    shuffled_order = np.arange(size)
    # shuffle() works in place and returns None, so keep the array handle.
    np.random.shuffle(shuffled_order)
    # argsort of a permutation is the permutation that undoes it.
    restore_order = np.argsort(shuffled_order)
    return shuffled_order, restore_order
def take_indexes(sequences, indexes):
    '''
    Gather rows of ``sequences`` along dim 0 with a separate index per column.

    Args:
        sequences: tensor of shape (T, B, C).
        indexes: integer tensor of shape (T', B); ``indexes[t, b]`` selects
            which row of ``sequences[:, b]`` ends up at output position ``t``.

    Returns:
        Tensor of shape (T', B, C) where
        ``out[t, b, :] == sequences[indexes[t, b], b, :]``.
    '''
    channels = sequences.shape[-1]
    # expand() broadcasts the (T', B) indexes over the channel dim as a
    # stride-0 view, so no (T', B, C) integer copy is allocated per call.
    gather_index = indexes.unsqueeze(-1).expand(-1, -1, channels)
    return torch.gather(sequences, 0, gather_index)
class PatchShuffle(torch.nn.Module):
    '''
    Randomly permute the patches of every sample and keep only a leading
    fraction of them.

    Args:
        ratio: fraction of patches to drop; ``1 - ratio`` of the ``T``
            patches (rounded down) are kept.
    '''
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        '''
        Args:
            patches: tensor of shape (T, B, C).

        Returns:
            ``(patches, forward_indexes, backward_indexes)`` where ``patches``
            holds the first ``remain_T`` shuffled rows, and both index tensors
            are ``torch.long`` of shape (T, B) on ``patches.device``.
            ``backward_indexes`` is the per-column inverse of
            ``forward_indexes``.
        '''
        T, B, C = patches.shape
        # T * (1 - ratio) can land a hair below the intended integer
        # (10 * (1 - 0.9) == 0.9999999999999998), and int() would then drop a
        # whole patch. Rounding away the representation noise first keeps the
        # floor semantics for genuinely fractional counts (2.5 -> 2).
        remain_T = int(round(T * (1 - float(self.ratio)), 8))

        # One independent permutation per sample, drawn from NumPy's global RNG.
        permutations = []
        for _ in range(B):
            order = np.arange(T)
            np.random.shuffle(order)
            permutations.append(order)
        forward_np = np.stack(permutations, axis=-1)   # (T, B)
        # Column-wise argsort of a permutation yields its inverse.
        backward_np = np.argsort(forward_np, axis=0)   # (T, B)

        forward_indexes = torch.as_tensor(forward_np, dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(backward_np, dtype=torch.long).to(patches.device)

        # Broadcast the indexes over the channel dim as a view and gather rows.
        gather_index = forward_indexes.unsqueeze(-1).expand(-1, -1, C)
        patches = torch.gather(patches, 0, gather_index)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes
class MAE_Encoder(torch.nn.Module):
    '''
    Encoder half of the masked autoencoder.

    Images are cut into patch embeddings by a strided convolution, a learned
    positional embedding is added, a random subset of patches is kept by
    ``PatchShuffle``, a cls token is prepended, and the sequence is run
    through a stack of transformer ``Block`` layers followed by LayerNorm.

    Args:
        image_size: side length used to size the positional embedding,
            which has ``(image_size // patch_size) ** 2`` rows.
        patch_size: kernel size and stride of the patchifying convolution.
        emb_dim: embedding width of every token.
        num_layer: number of transformer blocks.
        num_head: attention heads per block.
        mask_ratio: ratio handed to ``PatchShuffle``.
    '''
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        num_patches = (image_size // patch_size) ** 2

        # Attribute order is kept stable so parameter registration and
        # random initialisation order do not change.
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(num_patches, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        # Non-overlapping conv: one output position per patch, 3 input channels.
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*blocks)

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        '''Initialise the cls token and positional embedding (truncated normal, std 0.02).'''
        for parameter in (self.cls_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, img):
        '''
        Args:
            img: image batch laid out as (b, c, h, w) with 3 channels; the
                number of patches must match the positional embedding.

        Returns:
            ``(features, backward_indexes)``: features are laid out (t, b, c)
            with the cls token at ``t == 0``; ``backward_indexes`` is the
            inverse permutation returned by the shuffle.
        '''
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        # Only the kept patches continue; forward indexes are not needed here.
        visible, _, backward_indexes = self.shuffle(tokens)

        batch_cls = self.cls_token.expand(-1, visible.shape[1], -1)
        sequence = torch.cat([batch_cls, visible], dim=0)

        # The transformer blocks are fed batch-first input.
        sequence = rearrange(sequence, 't b c -> b t c')
        encoded = self.layer_norm(self.transformer(sequence))
        features = rearrange(encoded, 'b t c -> t b c')

        return features, backward_indexes
class MAE_Decoder(torch.nn.Module):
    '''
    Decoder half of the masked autoencoder.

    The encoded sequence is padded with a learned mask token up to the full
    length, put back into original patch order, run through transformer
    ``Block`` layers, and projected to pixels per patch.

    Args:
        image_size: side length used to size the positional embedding
            (``(image_size // patch_size) ** 2 + 1`` rows, one extra for the
            cls token) and the patch grid of the output image.
        patch_size: side length of one patch in pixels.
        emb_dim: embedding width of every token.
        num_layer: number of transformer blocks.
        num_head: attention heads per block.
    '''
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        grid = image_size // patch_size

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(grid ** 2 + 1, 1, emb_dim))

        blocks = [Block(emb_dim, num_head) for _ in range(num_layer)]
        self.transformer = torch.nn.Sequential(*blocks)

        # Each token is projected to the 3 * patch_size**2 values of one patch.
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange(
            '(h w) b (c p1 p2) -> b c (h p1) (w p2)',
            p1=patch_size, p2=patch_size, h=grid,
        )

        self.init_weight()

    def init_weight(self):
        '''Initialise the mask token and positional embedding (truncated normal, std 0.02).'''
        for parameter in (self.mask_token, self.pos_embedding):
            trunc_normal_(parameter, std=.02)

    def forward(self, features, backward_indexes):
        '''
        Args:
            features: encoder output laid out (t, b, c), cls token first.
            backward_indexes: (T, B) integer tensor that restores original
                patch order.

        Returns:
            ``(img, mask)`` as produced by ``patch2img``: the predicted image
            and a same-shaped tensor that is 1 on patches the encoder did not
            see and 0 elsewhere.
        '''
        # Encoder length, cls token included.
        num_encoded = features.shape[0]

        # Shift patch indexes by one and put index 0 first so the cls token
        # stays at position 0 after un-shuffling.
        cls_index = torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes)
        backward_indexes = torch.cat([cls_index, backward_indexes + 1], dim=0)

        # Pad with mask tokens up to the full sequence length.
        num_missing = backward_indexes.shape[0] - features.shape[0]
        padding = self.mask_token.expand(num_missing, features.shape[1], -1)
        features = torch.cat([features, padding], dim=0)

        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]  # remove global feature

        patches = self.head(features)

        # In shuffled order every position from num_encoded - 1 onward was a
        # mask token; mark those, then move the marks to original patch order.
        mask = torch.zeros_like(patches)
        mask[num_encoded - 1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)

        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask
class MAE_ViT(torch.nn.Module):
    '''
    Full masked autoencoder: an ``MAE_Encoder`` followed by an ``MAE_Decoder``
    that share image size, patch size and embedding width.

    Args:
        image_size, patch_size, emb_dim: forwarded to both halves.
        encoder_layer, encoder_head: depth and head count of the encoder.
        decoder_layer, decoder_head: depth and head count of the decoder.
        mask_ratio: forwarded to the encoder.
    '''
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(
            image_size=image_size,
            patch_size=patch_size,
            emb_dim=emb_dim,
            num_layer=encoder_layer,
            num_head=encoder_head,
            mask_ratio=mask_ratio,
        )
        self.decoder = MAE_Decoder(
            image_size=image_size,
            patch_size=patch_size,
            emb_dim=emb_dim,
            num_layer=decoder_layer,
            num_head=decoder_head,
        )

    def forward(self, img):
        '''Encode ``img`` and decode it again; returns ``(predicted_img, mask)``.'''
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features, backward_indexes)
        return predicted_img, mask
class ViT_Classifier(torch.nn.Module):
    '''
    Classification model built on top of a (pre-trained) encoder.

    The patchifier, positional embedding, cls token, transformer and final
    LayerNorm are taken from ``encoder`` and shared with it (not copied).
    Unlike the encoder's own forward pass, no patches are dropped: every
    patch is fed to the transformer, and the cls token's output feature is
    passed through dropout and a linear head.

    Args:
        encoder: object exposing ``cls_token``, ``pos_embedding``,
            ``patchify``, ``transformer`` and ``layer_norm``.
        dropout_p: dropout probability applied to the cls feature before
            the head; dropout is skipped when it is not positive.
        num_classes: number of output logits.
    '''
    def __init__(self, encoder : 'MAE_Encoder', dropout_p, num_classes=10) -> None:
        super().__init__()
        self.dropout_p = dropout_p
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.dropout = torch.nn.Dropout(dropout_p)
        # Head input width is the embedding width of the shared encoder.
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        '''
        Args:
            img: image batch laid out as (b, c, h, w).

        Returns:
            Logits of shape (b, num_classes).
        '''
        patches = self.patchify(img)                     # (b, c, h, w)
        patches = patches.flatten(2).permute(2, 0, 1)     # ((h w), b, c)
        patches = patches + self.pos_embedding

        batch_cls = self.cls_token.expand(-1, patches.shape[1], -1)
        tokens = torch.cat([batch_cls, patches], dim=0)   # (t, b, c)
        tokens = tokens.transpose(0, 1)                   # (b, t, c)

        features = self.layer_norm(self.transformer(tokens))

        # Only the cls token reaches the head, so select it first and apply
        # dropout to that (b, c) slice instead of the whole token sequence.
        cls_features = features[:, 0]
        if self.dropout_p > 0:
            cls_features = self.dropout(cls_features)
        logits = self.head(cls_features)
        return logits
class MAE_Encoder_FeatureExtractor(torch.nn.Module):
    '''
    Runs the encoder's layers over all patches of an image (nothing is
    dropped) and returns the resulting token features.

    The cls token, positional embedding, patchifier, transformer and final
    LayerNorm are taken from ``encoder`` and shared with it.
    '''
    def __init__(self, encoder : MAE_Encoder) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm

    def forward(self, img):
        '''
        Args:
            img: image batch laid out as (b, c, h, w).

        Returns:
            Features laid out (t, b, c): t indexes tokens (cls token first,
            then the patches), b the batch, c the feature channels.
        '''
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        batch_cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        sequence = torch.cat([batch_cls, tokens], dim=0)

        # The transformer blocks are fed batch-first input.
        sequence = rearrange(sequence, 't b c -> b t c')
        encoded = self.layer_norm(self.transformer(sequence))

        return rearrange(encoded, 'b t c -> t b c')
if __name__ == '__main__':
    # Smoke test: run the shuffle, encoder and decoder once on random data
    # and print the resulting shapes and the masked reconstruction loss.
    mask_ratio = 0.75

    patch_shuffle = PatchShuffle(mask_ratio)
    dummy_patches = torch.rand(16, 2, 10)
    kept_patches, forward_indexes, backward_indexes = patch_shuffle(dummy_patches)
    print(kept_patches.shape)

    img = torch.rand(2, 3, 32, 32)
    encoder = MAE_Encoder()
    decoder = MAE_Decoder()

    features, backward_indexes = encoder(img)
    # NOTE: forward_indexes is still the tensor from the PatchShuffle demo
    # above; the encoder call only returns backward indexes.
    print(forward_indexes.shape)

    predicted_img, mask = decoder(features, backward_indexes)
    print(predicted_img.shape)

    # Squared error counted on masked pixels only, scaled by the mask ratio.
    squared_error = (predicted_img - img) ** 2
    loss = torch.mean(squared_error * mask / mask_ratio)
    print(loss)