# -*- coding: utf-8 -*-
# @Time    : 6/10/21 5:04 PM
# @Author  : Yuan Gong
# @Affiliation  : Massachusetts Institute of Technology
# @Email   : yuangong@mit.edu
# @File    : ast_models.py

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
# import wget
os.environ['TORCH_HOME'] = '../../pretrained_models'
import timm
from timm.models.layers import to_2tuple, trunc_normal_


# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes; it is 527 for AudioSet, 50 for ESC-50, and 35 for Speech Commands V2-35
    :param fstride: the stride of patch splitting on the frequency dimension; for 16*16 patches, fstride=16 means no overlap, fstride=10 means an overlap of 6
    :param tstride: the stride of patch splitting on the time dimension; for 16*16 patches, tstride=16 means no overlap, tstride=10 means an overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: whether to use an ImageNet-pretrained model
    :param audioset_pretrain: whether to use a model pretrained on both ImageNet and full AudioSet
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384]; base224 and base384 are the same architecture but are trained differently during ImageNet pretraining
    :param verbose: whether to print a model summary during construction
    :param return_hidden_state: if set, forward() also returns the (normalized) hidden states of all transformer blocks
    :param pretrained_model: path to the AudioSet-pretrained checkpoint, required when audioset_pretrain=True
    """
    def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True, return_hidden_state=None, pretrained_model=None):

        super(ASTModel, self).__init__()
        # assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'
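        # worked example (a sketch, assuming the default arguments fstride=tstride=10, input_fdim=128, input_tdim=1024):
        # the 16*16 patch projection yields f_dim = (128 - 16) // 10 + 1 = 12 and t_dim = (1024 - 16) // 10 + 1 = 101,
        # i.e., 12 * 101 = 1212 patches; these are the numbers hard-coded in the AudioSet-pretrained branch below.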
        if verbose == True:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain), str(audioset_pretrain)))
        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.layers.patch_embed.PatchEmbed = PatchEmbed

        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if audioset_pretrain == False:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise Exception('Model size must be one of tiny224, small224, base224, base384.')

            # replace the patch embedding with the relaxed PatchEmbed defined above, keeping the pretrained weights
            tmp = PatchEmbed(img_size=self.v.patch_embed.img_size, patch_size=self.v.patch_embed.patch_size, in_chans=3, embed_dim=768)
            tmp.load_state_dict(self.v.patch_embed.state_dict())
            self.v.patch_embed = tmp
            # self.v.patch_embed = PatchEmbed(img_size=self.v.patch_embed.img_size, patch_size=self.v.patch_embed.patch_size,
            #                                 in_chans=3, embed_dim=768)

            self.original_num_patches = self.v.patch_embed.num_patches
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            # automatically get the intermediate shape
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the linear projection layer
            new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
            if imagenet_pretrain == True:
                # sum the three ImageNet input channels of the pretrained projection into one channel for the single-channel spectrogram
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

            # the positional embedding
            if imagenet_pretrain == True:
                # get the positional embedding from the deit model, skip the first two tokens (cls token and distillation token), and reshape it to the original 2D grid (e.g., 24*24 for base384).
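                # shape walk-through (a sketch, assuming the base384 model): pos_embed is (1, 578, 768) = 2 special tokens + 24*24 = 576 patches;
                # dropping the 2 tokens and reshaping gives (1, 768, 24, 24), a grid that is cropped or interpolated below to the
                # (f_dim, t_dim) spectrogram patch grid. For the 224 models the grid is 14*14 instead.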
                new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
                # cut (from the middle) or interpolate the second (time) dimension of the positional embedding
                if t_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
                # cut (from the middle) or interpolate the first (frequency) dimension of the positional embedding
                if f_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
                # flatten the positional embedding
                new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1, 2)
                # concatenate the above positional embedding with the cls token and distillation token of the deit model.
                self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
            else:
                # if the ImageNet-pretrained model is not used, just randomly initialize a learnable positional embedding
                # TODO can use sinusoidal positional embedding instead
                new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
                self.v.pos_embed = new_pos_embed
                trunc_normal_(self.v.pos_embed, std=.02)

        # now load a model that is pretrained on both ImageNet and AudioSet
        elif audioset_pretrain == True:
            if audioset_pretrain == True and imagenet_pretrain == False:
                raise ValueError('a model pretrained only on AudioSet is currently not supported, please set imagenet_pretrain = True to use the AudioSet pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only the base384 AudioSet pretrained model is available.')
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # if os.path.exists('../../pretrained_models/audioset_10_10_0.4593.pth') == False:
            #     # this model performs 0.4593 mAP on the audioset eval set
            #     audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
            #     wget.download(audioset_mdl_url, out='../../pretrained_models/audioset_10_10_0.4593.pth')
            sd = torch.load(pretrained_model, map_location=device)
            audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
            audio_model = torch.nn.DataParallel(audio_model)
            audio_model.load_state_dict(sd, strict=False)
            self.v = audio_model.module.v
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            # self.v.patch_embed.img_size = self.v.patch_embed.img_size
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the AudioSet checkpoint was trained with fstride=tstride=10 on 128*1024 inputs, i.e., a 12*101 = 1212 patch grid
            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101)
            # if the input sequence is shorter than the original AudioSet input (10s, t_dim=101), cut the positional embedding
            if t_dim < 101:
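                # worked example (using the 256-frame input from the demo below): t_dim = (256 - 16) // 10 + 1 = 25,
                # so the slice keeps the 25 centre columns [38:63] of the 101-column grid (50 - 25 // 2 = 38).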
                new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim / 2): 50 - int(t_dim / 2) + t_dim]
            # otherwise interpolate
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            if f_dim < 12:
                new_pos_embed = new_pos_embed[:, :, 6 - int(f_dim / 2): 6 - int(f_dim / 2) + f_dim, :]
            # otherwise interpolate
            elif f_dim > 12:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

        self.return_hidden_state = return_hidden_state

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        # run a dummy spectrogram through a convolution with the requested strides to get the patch-grid size
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        return f_dim, t_dim

    @autocast()
    def forward(self, x, external_features=None):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: the token embedding sequence after the final LayerNorm (the classification head self.mlp_head is not applied here); if return_hidden_state is set, also a tuple with the normalized output of every transformer block
        """
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        all_hidden_states = () if self.return_hidden_state else None
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
            if self.return_hidden_state:
                all_hidden_states = all_hidden_states + (self.v.norm(x),)
        x = self.v.norm(x)
        # x[:, 0] = (x[:, 0] + x[:, 1]) / 2
        # x = (x[:, 0] + x[:, 1]) / 2

        # x = self.mlp_head(x)
        if self.return_hidden_state:
            return x, all_hidden_states
        else:
            return x


if __name__ == '__main__':
    input_tdim = 100
    ast_mdl = ASTModel(input_tdim=input_tdim)
    # input a batch of 10 spectrograms, each with 100 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # since the classification head is not applied in forward(), the output is the token embedding sequence,
    # here of shape [10, 110, 768] (cls token + distillation token + 12*9 patches, each 768-dim)
    print(test_output.shape)

    input_tdim = 256
    # note: audioset_pretrain=True requires a local AudioSet checkpoint passed via the pretrained_model argument
    ast_mdl = ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True)
    # input a batch of 10 spectrograms, each with 256 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # again the output is the token embedding sequence, here of shape [10, 302, 768] (2 special tokens + 12*25 patches)
    print(test_output.shape)
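    # A minimal sketch (not part of the original demo) of the return_hidden_state option: when it is truthy,
    # forward() also returns a tuple containing the LayerNorm-ed output of every transformer block.
    ast_mdl = ASTModel(input_tdim=100, return_hidden_state=True)
    test_output, hidden_states = ast_mdl(torch.rand([2, 100, 128]))
    # len(hidden_states) equals the number of transformer blocks (12 for the base models);
    # each entry has the same shape as test_output, e.g., [2, 110, 768] here.
    print(len(hidden_states), hidden_states[0].shape)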