# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
from .mae_neck import MAEPretrainDecoder


@MODELS.register_module()
class MixMIMPretrainDecoder(MAEPretrainDecoder):
"""Decoder for MixMIM Pretraining. | |
Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. # noqa | |
Args: | |
num_patches (int): The number of total patches. Defaults to 196. | |
patch_size (int): Image patch size. Defaults to 16. | |
in_chans (int): The channel of input image. Defaults to 3. | |
embed_dim (int): Encoder's embedding dimension. Defaults to 1024. | |
encoder_stride (int): The output stride of MixMIM backbone. Defaults | |
to 32. | |
decoder_embed_dim (int): Decoder's embedding dimension. | |
Defaults to 512. | |
decoder_depth (int): The depth of decoder. Defaults to 8. | |
decoder_num_heads (int): Number of attention heads of decoder. | |
Defaults to 16. | |
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. | |
Defaults to 4. | |
norm_cfg (dict): Normalization layer. Defaults to LayerNorm. | |
init_cfg (Union[List[dict], dict], optional): Initialization config | |
dict. Defaults to None. | |
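
    Example:
        A minimal shape-check sketch, assuming the default arguments: the
        prediction head outputs ``encoder_stride**2 * 3 = 3072`` channels,
        and ``mask`` is assumed to be a binary per-token mask of shape
        (N, L, 1) that broadcasts over the channel dimension.

        >>> import torch
        >>> decoder = MixMIMPretrainDecoder()
        >>> decoder.init_weights()
        >>> inputs = torch.rand(2, 196, 1024)
        >>> mask = torch.randint(0, 2, (2, 196, 1)).float()
        >>> decoder(inputs, mask).shape  # both unmixed views, batch-stacked
        torch.Size([4, 196, 3072])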
""" | |

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 encoder_stride: int = 32,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            num_patches=num_patches,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        # Fixed (non-learnable) positional embedding; filled with 2D sin-cos
        # values in ``init_weights``.
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim),
            requires_grad=False)
        # Override the MAE prediction head: each token is projected back to
        # an ``encoder_stride x encoder_stride`` RGB pixel patch.
        self.decoder_pred = nn.Linear(decoder_embed_dim,
                                      encoder_stride**2 * 3)

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MixMIM decoder."""
        # Skip MAEPretrainDecoder.init_weights and call the grandparent's
        # initialization, since the position embedding is rebuilt below.
        super(MAEPretrainDecoder, self).init_weights()

        decoder_pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.decoder_pos_embed.shape[-1],
            cls_token=False)
        self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

        torch.nn.init.normal_(self.mask_token, std=.02)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): The input features, which are of shape
                (N, L, C).
            mask (torch.Tensor): The tensor indicating which tokens are
                masked.

        Returns:
            torch.Tensor: The reconstructed features for both unmixed views,
                which are of shape (2 * N, L, encoder_stride**2 * 3).
        """
        x = self.decoder_embed(x)
        B, L, C = x.shape

        mask_tokens = self.mask_token.expand(B, L, -1)
        x1 = x * (1 - mask) + mask_tokens * mask
        x2 = x * mask + mask_tokens * (1 - mask)
        x = torch.cat([x1, x2], dim=0)

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x
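
# A minimal registry-based construction sketch (the values are illustrative
# and mirror the constructor defaults above; ``MODELS.build`` is the standard
# mmengine registry entry point, enabled by the decorator on the class):
#
#     from mmpretrain.registry import MODELS
#     neck = MODELS.build(
#         dict(
#             type='MixMIMPretrainDecoder',
#             num_patches=196,
#             encoder_stride=32,
#             decoder_embed_dim=512,
#             decoder_depth=8))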