# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

import timm.models.vision_transformer


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with support for global average pooling.

    Extends the timm ``VisionTransformer`` with:

    * optional global average pooling over the patch tokens (instead of
      reading the class token), and
    * optional random masking of patch tokens in ``forward``, either along
      the flat token sequence (1D) or separately along the time and
      frequency axes of a spectrogram patch grid (2D).
    """

    def __init__(
        self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
    ):
        """Build the transformer.

        Args:
            global_pool: If true, the output is the mean of the patch tokens
                passed through ``fc_norm``; the parent's ``norm`` layer is
                removed. Requires ``norm_layer`` and ``embed_dim`` to be
                present in ``kwargs``.
            mask_2d: If true, masking in ``forward`` uses
                ``random_masking_2d``; otherwise ``random_masking``.
            use_custom_patch: Selects the frequency-axis patch count used by
                ``random_masking_2d`` (12 when true, 8 otherwise).
            **kwargs: Forwarded unchanged to the timm ``VisionTransformer``.
        """
        super().__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs["norm_layer"]
            embed_dim = kwargs["embed_dim"]
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm
        self.mask_2d = mask_2d
        self.use_custom_patch = use_custom_patch

    def _prepend_cls_and_encode(self, x, batch_size):
        """Prepend the class token, run all blocks and pool to one vector.

        ``x`` holds patch tokens that already carry their positional
        embedding. The class token gets the first positional embedding slot.
        """
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            return self.fc_norm(x)

        x = self.norm(x)
        return x[:, 0]

    def forward_features(self, x):
        """Embed ``x`` into patches and encode them without any masking.

        Returns the pooled feature: the normalized mean of the patch tokens
        when ``global_pool`` is set, otherwise the normalized class token.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        # pos_embed slot 0 belongs to the class token; the rest to patches.
        x = x + self.pos_embed[:, 1:, :]
        return self._prepend_cls_and_encode(x, B)

    def random_masking(self, x, mask_ratio):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort of random noise.

        Args:
            x: [N, L, D] sequence.
            mask_ratio: Fraction of the L tokens to drop;
                ``int(L * (1 - mask_ratio))`` tokens are kept.

        Returns:
            Tuple ``(x_masked, mask, ids_restore)``: the kept tokens
            ``[N, len_keep, D]``, a binary mask ``[N, L]`` in the original
            token order (0 is keep, 1 is remove), and the indices ``[N, L]``
            that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample; ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        index = ids_keep.unsqueeze(-1).expand(-1, -1, D)
        x_masked = torch.gather(x, dim=1, index=index)

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
        """Mask a spectrogram patch grid along time and frequency.

        Whole time steps are dropped with ``mask_t_prob`` and whole
        frequency bands with ``mask_f_prob``; selection is per-sample, done
        by argsort of random noise.

        The frequency-axis patch count F is 12 when ``use_custom_patch`` is
        set and 8 otherwise. The time-axis patch count T is derived as
        ``L // F``, which yields 64 (or 101 with custom patches) for the
        sequence lengths this method was originally hard-coded for. If L is
        not a multiple of F the reshape below raises.

        Args:
            x: [N, L, D] sequence, laid out time-major (reshaped to
                ``[N, T, F, D]``).
            mask_t_prob: Fraction of time steps to drop.
            mask_f_prob: Fraction of frequency bands to drop.

        Returns:
            Tuple ``(x_masked, None, None)`` where ``x_masked`` is
            ``[N, len_keep_T * len_keep_F, D]``. No mask or restore indices
            are produced; the two ``None`` values keep the return shape
            parallel to ``random_masking``.
        """
        N, L, D = x.shape  # batch, length, dim
        F = 12 if self.use_custom_patch else 8
        T = L // F

        # mask T
        x = x.reshape(N, T, F, D)
        len_keep_T = int(T * (1 - mask_t_prob))
        noise = torch.rand(N, T, device=x.device)  # noise in [0, 1]
        # sort noise for each sample; ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_keep = ids_shuffle[:, :len_keep_T]
        index = ids_keep.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, F, D)
        x = torch.gather(x, dim=1, index=index)  # N, len_keep_T(T'), F, D

        # mask F
        x = x.permute(0, 2, 1, 3)  # N T' F D => N F T' D
        len_keep_F = int(F * (1 - mask_f_prob))
        noise = torch.rand(N, F, device=x.device)  # noise in [0, 1]
        # sort noise for each sample; ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_keep = ids_shuffle[:, :len_keep_F]
        index = ids_keep.unsqueeze(-1).unsqueeze(-1).expand(
            -1, -1, len_keep_T, D
        )
        x_masked = torch.gather(x, dim=1, index=index)
        x_masked = x_masked.permute(0, 2, 1, 3)  # N F' T' D => N T' F' D
        x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)

        return x_masked, None, None

    def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
        """Embed ``x``, randomly drop patch tokens, then encode the rest.

        Uses ``random_masking_2d`` when ``self.mask_2d`` is set; otherwise
        ``random_masking`` with ``mask_t_prob`` as the mask ratio
        (``mask_f_prob`` is then unused).
        """
        B = x.shape[0]  # e.g. 4,1,1024,128
        x = self.patch_embed(x)  # e.g. 4, 512, 768

        x = x + self.pos_embed[:, 1:, :]
        # Test the configuration flag, not the bound method (which is always
        # truthy and would make the 1D branch unreachable).
        if self.mask_2d:
            x, mask, ids_restore = self.random_masking_2d(
                x, mask_t_prob, mask_f_prob
            )
        else:
            x, mask, ids_restore = self.random_masking(x, mask_t_prob)

        return self._prepend_cls_and_encode(x, B)

    # overwrite original timm
    def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
        """Compute head outputs, masking patches if either prob is > 0.

        Args:
            x: Input batch passed to ``patch_embed``.
            v: Unused; kept for call-site compatibility.
            mask_t_prob: Time masking probability (or 1D mask ratio).
            mask_f_prob: Frequency masking probability.
        """
        if mask_t_prob > 0.0 or mask_f_prob > 0.0:
            x = self.forward_features_mask(
                x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
            )
        else:
            x = self.forward_features(x)
        x = self.head(x)
        return x


def vit_small_patch16(**kwargs):
    """Build a ViT-Small: patch 16, width 384, depth 12, 6 heads."""
    model = VisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_base_patch16(**kwargs):
    """Build a ViT-Base: patch 16, width 768, depth 12, 12 heads."""
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_large_patch16(**kwargs):
    """Build a ViT-Large: patch 16, width 1024, depth 24, 16 heads."""
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_huge_patch14(**kwargs):
    """Build a ViT-Huge: patch 14, width 1280, depth 32, 16 heads."""
    model = VisionTransformer(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model