# Copyright (c) OpenMMLab. All rights reserved. import random from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from mmpretrain.models.backbones import MixMIMTransformer from mmpretrain.registry import MODELS from mmpretrain.structures import DataSample from ..utils import build_2d_sincos_position_embedding from .base import BaseSelfSupervisor @MODELS.register_module() class MixMIMPretrainTransformer(MixMIMTransformer): """MixMIM backbone for MixMIM pre-training. A PyTorch implement of : ` MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning `_ Args: arch (str | dict): MixMIM architecture. If use string, choose from 'base','large' and 'huge'. If use dict, it should have below keys: - **embed_dims** (int): The dimensions of embedding. - **depths** (int): The number of transformer encoder layers. - **num_heads** (int): The number of heads in attention modules. Defaults to 'base'. mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. img_size (int | tuple): The expected input image shape. Because we support dynamic input shape, just set the argument to mlp_ratio the most common input image shape. Defaults to 224. patch_size (int | tuple): The patch size in patch embedding. Defaults to 16. in_channels (int): The num of input channels. Defaults to 3. window_size (list): The height and width of the window. qkv_bias (bool): Whether to add bias for qkv in attention modules. Defaults to True. patch_cfg (dict): Extra config dict for patch embedding. Defaults to an empty dict. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. drop_rate (float): Probability of an element to be zeroed. Defaults to 0. drop_path_rate (float): Stochastic depth rate. Defaults to 0. attn_drop_rate (float): Attention drop rate. Defaults to 0. use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory cost. Defaults to False. mask_ratio (bool): The base ratio of total number of patches to be masked. Defaults to 0.5. range_mask_ratio (float): The range of mask ratio. Defaults to 0. init_cfg (dict, optional): Initialization config dict. Defaults to None. """ def __init__(self, arch: Union[str, dict] = 'base', mlp_ratio: float = 4, img_size: int = 224, patch_size: int = 4, in_channels: int = 3, window_size: List = [14, 14, 14, 7], qkv_bias: bool = True, patch_cfg: dict = dict(), norm_cfg: dict = dict(type='LN'), drop_rate: float = 0.0, drop_path_rate: float = 0.0, attn_drop_rate: float = 0.0, use_checkpoint: bool = False, mask_ratio: float = 0.5, range_mask_ratio: float = 0.0, init_cfg: Optional[dict] = None) -> None: super().__init__( arch=arch, mlp_ratio=mlp_ratio, img_size=img_size, patch_size=patch_size, in_channels=in_channels, window_size=window_size, qkv_bias=qkv_bias, patch_cfg=patch_cfg, norm_cfg=norm_cfg, drop_rate=drop_rate, drop_path_rate=drop_path_rate, attn_drop_rate=attn_drop_rate, use_checkpoint=use_checkpoint, init_cfg=init_cfg) self.mask_ratio = mask_ratio self.range_mask_ratio = range_mask_ratio def init_weights(self): """Initialize position embedding, patch embedding.""" super(MixMIMTransformer, self).init_weights() pos_embed = build_2d_sincos_position_embedding( int(self.num_patches**.5), self.absolute_pos_embed.shape[-1], cls_token=False) self.absolute_pos_embed.data.copy_(pos_embed.float()) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): # we use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def random_masking(self, x: torch.Tensor, mask_ratio: float = 0.5) -> Tuple[torch.Tensor]: """Generate the mask for MixMIM Pretraining. Args: x (torch.Tensor): Image with data augmentation applied, which is of shape B x L x C. mask_ratio (float): The mask ratio of total patches. Defaults to 0.5. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - mask_s1 (torch.Tensor): mask with stride of self.encoder_stride // 8. - mask_s2 (torch.Tensor): mask with stride of self.encoder_stride // 4. - mask_s3 (torch.Tensor): mask with stride of self.encoder_stride // 2. - mask (torch.Tensor): mask with stride of self.encoder_stride. """ B, C, H, W = x.shape out_H = H // self.encoder_stride out_W = W // self.encoder_stride s3_H, s3_W = out_H * 2, out_W * 2 s2_H, s2_W = out_H * 4, out_W * 4 s1_H, s1_W = out_H * 8, out_W * 8 seq_l = out_H * out_W # use a shared mask for a batch images mask = torch.zeros([1, 1, seq_l], device=x.device) mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio) noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1] # ascend: small is keep, large is removed mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)] mask.scatter_(2, mask_idx, 1) mask = mask.reshape(1, 1, out_H, out_W) mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest') mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest') mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest') mask = mask.reshape(1, out_H * out_W, 1).contiguous() mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous() mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous() mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous() return mask_s1, mask_s2, mask_s3, mask def forward(self, x: torch.Tensor, mask: Optional[bool] = True) -> Tuple[torch.Tensor]: """Generate features for masked images. This function generates mask and masks some patches randomly and get the hidden features for visible patches. Args: x (torch.Tensor): Input images, which is of shape B x C x H x W. mask (bool, optional): To indicate whether the forward containing ``mask`` or not. Returns: Tuple[torch.Tensor, torch.Tensor]: - x (torch.Tensor): hidden features, which is of shape B x L x C. - mask_s4 (torch.Tensor): the mask tensor for the last layer. """ if mask is None or False: return super().forward(x) else: mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking( x, self.mask_ratio) x, _ = self.patch_embed(x) x = x * (1. - mask_s1) + x.flip(0) * mask_s1 x = x + self.absolute_pos_embed x = self.drop_after_pos(x) for idx, layer in enumerate(self.layers): if idx == 0: x = layer(x, attn_mask=mask_s1) elif idx == 1: x = layer(x, attn_mask=mask_s2) elif idx == 2: x = layer(x, attn_mask=mask_s3) elif idx == 3: x = layer(x, attn_mask=mask_s4) x = self.norm(x) return x, mask_s4 @MODELS.register_module() class MixMIM(BaseSelfSupervisor): """MixMIM. Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning. `_. """ def __init__(self, backbone: dict, neck: Optional[dict] = None, head: Optional[dict] = None, pretrained: Optional[str] = None, data_preprocessor: Optional[Union[dict, nn.Module]] = None, init_cfg: Optional[dict] = None): head.update(dict(patch_size=neck['encoder_stride'])) super().__init__( backbone=backbone, neck=neck, head=head, pretrained=pretrained, data_preprocessor=data_preprocessor, init_cfg=init_cfg) def extract_feat(self, inputs: torch.Tensor): return self.backbone(inputs, mask=None) def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], **kwargs) -> Dict[str, torch.Tensor]: """The forward function in training. Args: inputs (torch.Tensor): The input images. data_samples (List[DataSample]): All elements required during the forward function. Returns: Dict[str, torch.Tensor]: A dictionary of loss components. """ latent, mask = self.backbone(inputs) x_rec = self.neck(latent, mask) loss = self.head.loss(x_rec, inputs, mask) losses = dict(loss=loss) return losses