import torch from mmcls.models import VisionTransformer from torch import nn from torch.utils.checkpoint import checkpoint import copy def build_2d_sincos_position_embedding(patches_resolution, embed_dims, temperature=10000., cls_token=False): """The function is to build position embedding for model to obtain the position information of the image patches.""" if isinstance(patches_resolution, int): patches_resolution = (patches_resolution, patches_resolution) h, w = patches_resolution grid_w = torch.arange(w, dtype=torch.float32) grid_h = torch.arange(h, dtype=torch.float32) grid_w, grid_h = torch.meshgrid(grid_w, grid_h) assert embed_dims % 4 == 0, \ 'Embed dimension must be divisible by 4.' pos_dim = embed_dims // 4 omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim omega = 1. / (temperature**omega) out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) pos_emb = torch.cat( [ torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h) ], dim=1, )[None, :, :] if cls_token: cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32) pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1) return pos_emb class MAEViT(VisionTransformer): """Vision Transformer for MAE pre-training. A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_ Args: arch (str | dict): Vision Transformer architecture Default: 'b' img_size (int | tuple): Input image size patch_size (int | tuple): The patch size out_indices (Sequence | int): Output from which stages. Defaults to -1, means the last stage. drop_rate (float): Probability of an element to be zeroed. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. output_cls_token (bool): Whether output the cls_token. If set True, `with_cls_token` must be True. Defaults to True. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. layer_cfgs (Sequence | dict): Configs of each transformer layer in encoder. Defaults to an empty dict. mask_ratio (bool): The ratio of total number of patches to be masked. Defaults to 0.75. init_cfg (dict, optional): Initialization config dict. Defaults to None. """ arch_zoo = { **dict.fromkeys( ['mocov3-s', 'mocov3-small'], { 'embed_dims': 384, 'num_layers': 12, 'num_heads': 12, 'feedforward_channels': 1536, }), **dict.fromkeys( ['b', 'base'], { 'embed_dims': 768, 'num_layers': 12, 'num_heads': 12, 'feedforward_channels': 3072 }), } def __init__(self, arch='b', img_size=224, patch_size=16, out_indices=-1, drop_rate=0, drop_path_rate=0, norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, output_cls_token=False, interpolate_mode='bicubic', patch_cfg=dict(), layer_cfgs=dict(), gradientCKPT=False, mask_ratio=0.75, init_cfg=None): super().__init__( arch=arch, img_size=img_size, patch_size=patch_size, out_indices=out_indices, drop_rate=drop_rate, drop_path_rate=drop_path_rate, norm_cfg=norm_cfg, final_norm=final_norm, output_cls_token=output_cls_token, interpolate_mode=interpolate_mode, patch_cfg=patch_cfg, layer_cfgs=layer_cfgs, init_cfg=init_cfg) self.gradientCKPT = gradientCKPT self.pos_embed.requires_grad = False self.mask_ratio = mask_ratio self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] # self.mask_embedding = copy.deepcopy(self.patch_embed) # self.mask_embedding.norm = None def init_weights(self): super(MAEViT, self).init_weights() if not (isinstance(self.init_cfg, dict) and self.init_cfg['type'] == 'Pretrained'): # initialize position embedding in backbone pos_embed = build_2d_sincos_position_embedding( self.patch_resolution, self.pos_embed.shape[-1], cls_token=True) self.pos_embed.data.copy_(pos_embed.float()) w = self.patch_embed.projection.weight.data torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) torch.nn.init.normal_(self.cls_token, std=.02) self.apply(self._init_weights) # mask_embedding transfers pixel level mask to token level # self.mask_embedding.apply(self._init_mask_embedding) # for para in self.mask_embedding.parameters(): # para.requires_grad = False def _init_mask_embedding(self,m): if hasattr(m,'weight'): nn.init.constant_(m.weight,1.0) if hasattr(m, 'bias'): nn.init.constant_(m.bias,0) def _init_weights(self, m): if isinstance(m, nn.Linear): torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def random_masking(self, x, mask_ratio=0.75, attn_mask=None): """Generate the mask for MAE Pre-training. Args: x (torch.tensor): Image with data augmentation applied. mask_ratio (float): The mask ratio of total patches. Defaults to 0.75. Returns: tuple[Tensor, Tensor, Tensor]: masked image, mask and the ids to restore original image. - x_masked (Tensor): masked image. - mask (Tensor): mask used to mask image. - ids_restore (Tensor): ids to restore original image. """ N, L, D = x.shape # batch, length, dim len_keep = int(L * (1 - mask_ratio)) noise = torch.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample ids_shuffle = torch.argsort( noise, dim=1) # ascend: small is keep, large is remove ids_restore = torch.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = torch.gather( x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # modified_attn_mask = None if attn_mask is None else torch.gather(attn_mask,dim=1, index=ids_keep) # generate the binary mask: 0 is keep, 1 is remove mask = torch.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = torch.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore #, modified_attn_mask def generate_mask(self, pixel_level_attn_mask): ''' pixel_level_attn_mask: (0,1) attn mask with the same shape as img ''' if pixel_level_attn_mask is None: return None # H, W = patch_resolution # B, C = pixel_level_attn_mask.shape[:2] # attn_mask = torch.ones((B,C,H,W),device=pixel_level_attn_mask) # H_splited = torch.chunk(pixel_level_attn_mask, H, -2) # HW_splited_mask = (torch.chunk(Hs, W, -1) for Hs in H_splited) # if HW_splited_mask[:,:,hi,wi].sum().item() == 0: # attn_mask[:,:,hi,wi] = 0 # mask_patches = self.mask_embedding(pixel_level_attn_mask)[0] # attn_mask = mask_patches.sum(-1) != 0 # return attn_mask def extract_feat(self, img ,attn_mask=None): x, *_ = self.forward(img,attn_mask) if self.output_cls_token: return x[:,0,:] else: return torch.mean(x,dim=1) def forward(self, x, attn_mask=None): if attn_mask is not None: assert self.output_cls_token B = x.shape[0] x = self.patch_embed(x)[0] # add pos embed w/o cls token x = x + self.pos_embed[:, 1:1+x.shape[1], :] # masking: length -> length * mask_ratio if True: assert self.mask_ratio == 0. else: x, mask, ids_restore = self.random_masking(x, self.mask_ratio) # append cls token cls_token = self.cls_token + self.pos_embed[:, :1, :] cls_tokens = cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) x = self.drop_after_pos(x) # if attn_mask is not None: # attn_mask = torch.concat((torch.ones((B,1),device=attn_mask.device) , attn_mask),dim=1) for i, layer in enumerate(self.layers): if self.gradientCKPT: x = checkpoint(layer,x) # ,attn_mask else: x = layer(x) # ,attn_mask if i == len(self.layers) - 1 and self.final_norm: x = self.norm1(x) if True: return x else: return (x, mask, ids_restore) def forward_generator(self, x, attn_mask=None): if attn_mask is not None: assert self.output_cls_token B = x.shape[0] x = self.patch_embed(x)[0] # add pos embed w/o cls token x = x + self.pos_embed[:, 1:1+x.shape[1], :] # append cls token cls_token = self.cls_token + self.pos_embed[:, :1, :] cls_tokens = cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) x = self.drop_after_pos(x) for i, layer in enumerate(self.layers): if self.gradientCKPT: x = checkpoint(layer,x) # ,attn_mask else: x = layer(x) # ,attn_mask if i == len(self.layers) - 1 and self.final_norm: x = self.norm1(x) x = x if (new_x:=(yield x)) is None else new_x debug = False if debug: print(f'layer {i}-th forwarded')