# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding


@MODELS.register_module()
class MAEPretrainDecoder(BaseModule):
    """Decoder for MAE Pre-training.

    Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa

    Args:
        num_patches (int): The number of total patches. Must be a perfect
            square, because ``init_weights`` builds the fixed sin-cos
            position embedding from ``sqrt(num_patches)``. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 8.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (float): Ratio of mlp hidden dim to decoder's embedding dim.
            The hidden dim is ``int(mlp_ratio * decoder_embed_dim)``.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
        predict_feature_dim (int, optional): Output dimension of the final
            prediction layer. If None, it is set to
            ``patch_size ** 2 * in_chans``, i.e. the number of pixel values
            in one patch. Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmpretrain.models import MAEPretrainDecoder
        >>> import torch
        >>> self = MAEPretrainDecoder()
        >>> self.eval()
        >>> inputs = torch.rand(1, 50, 1024)
        >>> ids_restore = torch.arange(0, 196).unsqueeze(0)
        >>> level_outputs = self.forward(inputs, ids_restore)
        >>> print(tuple(level_outputs.shape))
        (1, 196, 768)
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: float = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 predict_feature_dim: Optional[int] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_patches = num_patches

        # used to convert the dim of features from encoder to the dim
        # compatible with that of decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # learnable token that fills every masked (dropped) patch position
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # create new position embedding, different from that in encoder
        # and is not learnable; the extra slot (+1) is for the cls token
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, decoder_embed_dim),
            requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            TransformerEncoderLayer(
                decoder_embed_dim,
                decoder_num_heads,
                int(mlp_ratio * decoder_embed_dim),
                qkv_bias=True,
                norm_cfg=norm_cfg) for _ in range(decoder_depth)
        ])

        self.decoder_norm_name, decoder_norm = build_norm_layer(
            norm_cfg, decoder_embed_dim, postfix=1)
        self.add_module(self.decoder_norm_name, decoder_norm)

        # Used to map features to pixels
        if predict_feature_dim is None:
            predict_feature_dim = patch_size**2 * in_chans
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, predict_feature_dim, bias=True)

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MAE decoder.

        The position embedding is filled with a fixed 2D sin-cos embedding
        (including a cls-token slot). The mask token is drawn from a normal
        distribution with std 0.02.

        Raises:
            ValueError: If ``num_patches`` is not a perfect square, since the
                sin-cos embedding is generated from ``sqrt(num_patches)`` and
                would not match the ``num_patches + 1`` slots of
                ``decoder_pos_embed``.
        """
        super().init_weights()

        # Round rather than truncate so a float sqrt landing just below an
        # integer cannot shrink the grid, then reject non-square inputs with
        # a clear message instead of an opaque shape mismatch in ``copy_``.
        grid_size = int(round(self.num_patches**.5))
        if grid_size * grid_size != self.num_patches:
            raise ValueError(
                'MAEPretrainDecoder requires `num_patches` to be a perfect '
                f'square, but got {self.num_patches}.')

        decoder_pos_embed = build_2d_sincos_position_embedding(
            grid_size, self.decoder_pos_embed.shape[-1], cls_token=True)
        self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

        torch.nn.init.normal_(self.mask_token, std=.02)

    @property
    def decoder_norm(self) -> nn.Module:
        """The normalization layer of decoder."""
        return getattr(self, self.decoder_norm_name)

    def forward(self, x: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """The forward function.

        The process computes the visible patches' features vectors and the
        mask tokens to output feature vectors, which will be used for
        reconstruction.

        Args:
            x (torch.Tensor): hidden features from the encoder, of shape
                B x (1 + num_visible) x C. The first token is the cls token
                and the remaining ``num_visible`` tokens are the kept
                (unmasked) patches.
            ids_restore (torch.Tensor): ids to restore original image, an
                int64 tensor of shape B x num_patches. Entry ``i`` is the
                index, within the ``[visible patches, mask tokens]`` sequence,
                of the token that belongs at original patch position ``i``.

        Returns:
            torch.Tensor: The reconstructed feature vectors, which is of shape
            B x (num_patches) x predict_feature_dim.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence; the ``+ 1`` accounts for the cls
        # token that is part of ``x`` but not of ``ids_restore``
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        # unshuffle back to the original patch order
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        # re-attach the cls token in front
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed (requires ids_restore.shape[1] == num_patches)
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x


@MODELS.register_module()
class ClsBatchNormNeck(BaseModule):
    """Normalize cls token across batch before head.

    This module is proposed by MAE, when running linear probing.

    Args:
        input_features (int): The dimension of features.
        affine (bool): a boolean value that when set to ``True``, this module
            has learnable affine parameters. Defaults to False.
        eps (float): a value added to the denominator for numerical
            stability. Defaults to 1e-6.
        init_cfg (Dict or List[Dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 input_features: int,
                 affine: bool = False,
                 eps: float = 1e-6,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps)

    def forward(
        self, inputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[torch.Tensor, ...]:
        """The forward function.

        Args:
            inputs (Tuple[torch.Tensor]): Features to normalize. Each one is
                passed through the same ``nn.BatchNorm1d(input_features)``,
                so presumably each is a B x input_features cls-token feature
                (TODO confirm against the backbone's output type).

        Returns:
            Tuple[torch.Tensor]: The normalized features, in input order.
        """
        # Only apply batch norm to cls_token
        return tuple(self.bn(input_) for input_ in inputs)