# --------------------------------------------------------
# References:
# MAE: https://github.com/IcarusWizard/MAE
# --------------------------------------------------------
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block


def random_indexes(size : int):
    # Returns a random permutation of [0, size) and the permutation that undoes it.
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    # Gathers along the token dimension: out[t, b] = sequences[indexes[t, b], b].
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        # Shuffle the patch sequence per sample, then keep only the visible subset.
        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes
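# Illustrative sanity check (not part of the reference implementation): the forward/backward
# index pairs from random_indexes are inverse permutations, so gathering with the forward
# indexes and then the backward indexes restores the original patch order. This is what lets
# the decoder later undo the shuffle performed by PatchShuffle. The tensor sizes below are
# arbitrary example values.
def _check_index_roundtrip(T=16, B=2, C=8):
    patches = torch.rand(T, B, C)
    indexes = [random_indexes(T) for _ in range(B)]
    forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long)
    backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long)
    shuffled = take_indexes(patches, forward_indexes)
    restored = take_indexes(shuffled, backward_indexes)
    assert torch.equal(restored, patches)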
class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        # Drop masked patches; only the visible subset is passed through the transformer.
        patches, forward_indexes, backward_indexes = self.shuffle(patches)

        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        # Shift the indexes by one to account for the cls token kept at position 0.
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        # Append mask tokens for the dropped patches, then restore the original patch order.
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:] # remove global feature

        patches = self.head(features)
        # Positions filled by mask tokens (everything after the T-1 visible patches) get mask = 1.
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features, backward_indexes)
        return predicted_img, mask

class ViT_Classifier(torch.nn.Module):
    '''
    A classification head on top of the MAE encoder, used for fine-tuning on downstream tasks.
    We do not use the MAE_ViT encoder forward pass directly, both because a classification head
    must be added and because the masked autoencoder feeds only a subset of patches to the
    transformer, so its features lack the global information of the image and are unsuitable for
    classification. Here all patches are visible and a linear head on the cls token produces the logits.
    '''
    def __init__(self, encoder : MAE_Encoder, dropout_p, num_classes=10) -> None:
        super().__init__()
        self.dropout_p = dropout_p
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.dropout = torch.nn.Dropout(dropout_p) # Add dropout layer
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        # t is the number of patches, b is the batch size, c is the number of features
        features = rearrange(features, 'b t c -> t b c')
        if self.dropout_p > 0:
            features = self.dropout(features) # Apply dropout before the final head
        logits = self.head(features[0]) # only use the cls token
        return logits
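# Illustrative sketch (not part of the reference implementation): building a ViT_Classifier from a
# pretrained MAE encoder for fine-tuning. The checkpoint path, dropout rate, and class count below
# are placeholder assumptions, and the loading code assumes the checkpoint holds an MAE_ViT
# state_dict saved with torch.save(model.state_dict(), ...). Adapt it to however the training
# script actually saves weights.
def build_finetune_classifier(checkpoint_path='mae_pretrained.pt', dropout_p=0.1, num_classes=10):
    mae = MAE_ViT()
    mae.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    # The classifier shares the pretrained encoder's parameters (cls_token, pos_embedding,
    # patchify, transformer, layer_norm) and adds a fresh linear head.
    return ViT_Classifier(mae.encoder, dropout_p=dropout_p, num_classes=num_classes)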
class MAE_Encoder_FeatureExtractor(torch.nn.Module):
    '''
    A feature extractor that extracts features from the encoder of the Masked Autoencoder.
    Unlike MAE_Encoder.forward, no patches are masked, so the returned features cover the full image.
    '''
    def __init__(self, encoder : MAE_Encoder) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        # t is the number of patches, b is the batch size, c is the number of features
        features = rearrange(features, 'b t c -> t b c')
        return features

if __name__ == '__main__':
    shuffle = PatchShuffle(0.75)
    a = torch.rand(16, 2, 10)
    b, forward_indexes, backward_indexes = shuffle(a)
    print(b.shape)

    img = torch.rand(2, 3, 32, 32)
    encoder = MAE_Encoder()
    decoder = MAE_Decoder()
    features, backward_indexes = encoder(img)
    print(forward_indexes.shape)
    predicted_img, mask = decoder(features, backward_indexes)
    print(predicted_img.shape)
    loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
    print(loss)
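    # Illustrative sketch (not part of the reference implementation): one MAE pretraining step
    # with the full model on the random batch above. The optimizer choice and learning rate are
    # assumptions for demonstration only. Only masked patches contribute to the reconstruction
    # loss, and dividing by the mask ratio (0.75) keeps the loss scale comparable across ratios.
    model = MAE_ViT(mask_ratio=0.75)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)
    predicted_img, mask = model(img)
    loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())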