# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding


@MODELS.register_module()
class MAEPretrainDecoder(BaseModule):
"""Decoder for MAE Pre-training.
Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The number of input image channels. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
decoder_embed_dim (int): Decoder's embedding dimension.
Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 8.
decoder_num_heads (int): Number of attention heads of decoder.
Defaults to 16.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.

Example:
>>> from mmpretrain.models import MAEPretrainDecoder
>>> import torch
>>> self = MAEPretrainDecoder()
>>> self.eval()
>>> inputs = torch.rand(1, 50, 1024)
>>> ids_restore = torch.arange(0, 196).unsqueeze(0)
>>> level_outputs = self.forward(inputs, ids_restore)
>>> print(tuple(level_outputs.shape))
(1, 196, 768)
"""
def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 1024,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
predict_feature_dim: Optional[float] = None,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.num_patches = num_patches
# used to convert the dim of features from encoder to the dim
# compatible with that of decoder
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
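
        # learnable mask token that is shared by (and repeated for) every
        # masked patch position in the decoder input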
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
# create new position embedding, different from that in encoder
# and is not learnable
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches + 1, decoder_embed_dim),
requires_grad=False)
self.decoder_blocks = nn.ModuleList([
TransformerEncoderLayer(
decoder_embed_dim,
decoder_num_heads,
int(mlp_ratio * decoder_embed_dim),
qkv_bias=True,
norm_cfg=norm_cfg) for _ in range(decoder_depth)
])
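
        # build the final decoder norm layer; it is registered under a
        # generated name and exposed through the ``decoder_norm`` property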
self.decoder_norm_name, decoder_norm = build_norm_layer(
norm_cfg, decoder_embed_dim, postfix=1)
self.add_module(self.decoder_norm_name, decoder_norm)
# Used to map features to pixels
if predict_feature_dim is None:
predict_feature_dim = patch_size**2 * in_chans
self.decoder_pred = nn.Linear(
            decoder_embed_dim, predict_feature_dim, bias=True)

def init_weights(self) -> None:
"""Initialize position embedding and mask token of MAE decoder."""
super().init_weights()
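
        # the decoder position embedding is fixed (requires_grad=False), so it
        # is filled with 2D sin-cos values here rather than learned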
decoder_pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.decoder_pos_embed.shape[-1],
cls_token=True)
self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
        torch.nn.init.normal_(self.mask_token, std=.02)

@property
def decoder_norm(self):
"""The normalization layer of decoder."""
        return getattr(self, self.decoder_norm_name)

def forward(self, x: torch.Tensor,
ids_restore: torch.Tensor) -> torch.Tensor:
"""The forward function.
The process computes the visible patches' features vectors and the mask
tokens to output feature vectors, which will be used for
reconstruction.
Args:
x (torch.Tensor): hidden features, which is of shape
B x (L * mask_ratio) x C.
ids_restore (torch.Tensor): ids to restore original image.
Returns:
torch.Tensor: The reconstructed feature vectors, which is of
shape B x (num_patches) x C.
"""
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
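        # unshuffle: gather tokens back into the original patch order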
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
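        # prepend the cls token back to the sequence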
x = torch.cat([x[:, :1, :], x_], dim=1)
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
        return x


@MODELS.register_module()
class ClsBatchNormNeck(BaseModule):
"""Normalize cls token across batch before head.
This module is proposed by MAE, when running linear probing.
Args:
input_features (int): The dimension of features.
affine (bool): a boolean value that when set to ``True``, this module
has learnable affine parameters. Defaults to False.
eps (float): a value added to the denominator for numerical stability.
Defaults to 1e-6.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
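
    A minimal usage sketch, assuming the class is importable from
    ``mmpretrain.models`` (as in the ``MAEPretrainDecoder`` example above):

    Example:
        >>> import torch
        >>> from mmpretrain.models import ClsBatchNormNeck
        >>> neck = ClsBatchNormNeck(input_features=768)
        >>> neck.eval()
        >>> cls_token = torch.rand(4, 768)
        >>> outs = neck((cls_token, ))
        >>> print(outs[0].shape)
        torch.Size([4, 768])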
"""
def __init__(self,
input_features: int,
affine: bool = False,
eps: float = 1e-6,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg)
        self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps)

def forward(
self,
inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]:
"""The forward function."""
# Only apply batch norm to cls_token
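        # each element of ``inputs`` is expected to be a (B, C) cls token;
        # BatchNorm1d normalizes every channel over the batch dimension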
inputs = [self.bn(input_) for input_ in inputs]
return tuple(inputs)