File size: 13,125 Bytes
3eb682b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
# -*- coding: utf-8 -*-
# @Time : 6/10/21 5:04 PM
# @Author : Yuan Gong
# @Affiliation : Massachusetts Institute of Technology
# @Email : yuangong@mit.edu
# @File : ast_models.py
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
# import wget
os.environ['TORCH_HOME'] = '../../pretrained_models'
import timm
from timm.models.layers import to_2tuple,trunc_normal_
# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    """Split a 2D input into patches and linearly embed each patch.

    Replacement for timm's ``PatchEmbed``: ``forward`` performs no check of
    the input size against ``img_size``, so inputs of other sizes can be
    embedded.

    :param img_size: nominal input size (an int is expanded with ``to_2tuple``)
    :param patch_size: patch size, used as both kernel size and stride of the projection
    :param in_chans: number of input channels
    :param embed_dim: number of output channels of the projection (embedding width)
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # Patch count of the nominal (non-overlapping) grid.
        patches_dim0 = self.img_size[0] // self.patch_size[0]
        patches_dim1 = self.img_size[1] // self.patch_size[1]
        self.num_patches = patches_dim1 * patches_dim0
        # kernel == stride, so the convolution embeds each patch independently.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        """Project ``x`` and return the patches as a sequence with the channel axis last."""
        projected = self.proj(x)
        # Collapse the two spatial axes into one, then move it ahead of the channel axis.
        sequence = projected.flatten(2)
        return sequence.transpose(1, 2)
class ASTModel(nn.Module):
    """
    The AST (Audio Spectrogram Transformer) model.

    Note that ``forward`` returns the normalized token embeddings of the transformer;
    ``mlp_head`` is created (sized by ``label_dim``) but is not applied in ``forward``.

    :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch splitting on the frequency dimension, for 16*16 patches, fstride=16 means no overlap, fstride=10 means overlap of 6
    :param tstride: the stride of patch splitting on the time dimension, for 16*16 patches, tstride=16 means no overlap, tstride=10 means overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: if use ImageNet pretrained model
    :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model (requires ``pretrained_model``)
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base384 are the same model, but are trained differently during ImageNet pretraining.
    :param verbose: if truthy, print a short summary while building the model
    :param return_hidden_state: if truthy, ``forward`` additionally returns the normalized output of every transformer block
    :param pretrained_model: checkpoint passed to ``torch.load``; required when ``audioset_pretrain`` is set
    :raises ValueError: on an unknown ``model_size`` or an unsupported / incomplete AudioSet configuration
    """

    # Patch grid (frequency, time) of the AudioSet checkpoint's positional embedding:
    # 128 bins x 1024 frames with 16x16 patches and stride 10 in both directions.
    _AUDIOSET_GRID = (12, 101)

    # DeiT (distilled) backbone to create for each supported ``model_size``.
    _DEIT_MODELS = {
        'tiny224': 'vit_deit_tiny_distilled_patch16_224',
        'small224': 'vit_deit_small_distilled_patch16_224',
        'base224': 'vit_deit_base_distilled_patch16_224',
        'base384': 'deit_base_distilled_patch16_384',
    }

    def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024,
                 imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True,
                 return_hidden_state=None, pretrained_model=None):
        super(ASTModel, self).__init__()
        if verbose:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))

        # Validate the configuration up front, before any download or timm patching happens.
        if audioset_pretrain:
            if not imagenet_pretrain:
                raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only has base384 AudioSet pretrained model.')
            if pretrained_model is None:
                raise ValueError('audioset_pretrain=True requires pretrained_model to point to an AudioSet pretrained AST checkpoint.')
        elif model_size not in self._DEIT_MODELS:
            raise ValueError('Model size must be one of tiny224, small224, base224, base384.')

        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.layers.patch_embed.PatchEmbed = PatchEmbed

        if audioset_pretrain:
            self._init_from_audioset(label_dim, fstride, tstride, input_fdim, input_tdim, verbose, pretrained_model)
        else:
            self._init_from_deit(label_dim, fstride, tstride, input_fdim, input_tdim, imagenet_pretrain, model_size, verbose)
        self.return_hidden_state = return_hidden_state

    @staticmethod
    def _resolve_timm_name(name):
        """Return ``name`` or its ``vit_``-prefixed / un-prefixed twin, whichever timm has registered."""
        # NOTE(review): timm appears to have renamed the DeiT entries (vit_deit_* -> deit_*)
        # between releases; the original name is tried first so working setups are unchanged.
        alternate = name[len('vit_'):] if name.startswith('vit_') else 'vit_' + name
        registered = set(timm.list_models())
        for candidate in (name, alternate):
            if candidate in registered:
                return candidate
        # Neither is registered: let timm.create_model report the original name.
        return name

    @staticmethod
    def _resize_pos_embed(grid, f_dim, t_dim):
        """Centre-crop or bilinearly interpolate a (1, dim, F, T) grid to (1, dim, f_dim, t_dim)."""
        ref_f, ref_t = grid.shape[2], grid.shape[3]
        # time axis first: cut (from the middle) when shorter, interpolate when longer
        if t_dim < ref_t:
            start = ref_t // 2 - t_dim // 2
            grid = grid[:, :, :, start:start + t_dim]
        elif t_dim > ref_t:
            grid = torch.nn.functional.interpolate(grid, size=(ref_f, t_dim), mode='bilinear')
        # then the frequency axis, same policy
        if f_dim < ref_f:
            start = ref_f // 2 - f_dim // 2
            grid = grid[:, :, start:start + f_dim, :]
        elif f_dim > ref_f:
            grid = torch.nn.functional.interpolate(grid, size=(f_dim, t_dim), mode='bilinear')
        return grid

    def _build_pos_embed(self, ref_f, ref_t, f_dim, t_dim):
        """Adapt ``self.v.pos_embed`` from a (ref_f, ref_t) patch grid to a (f_dim, t_dim) grid."""
        dim = self.original_embedding_dim
        # skip the first two tokens (cls token and distillation token) and restore the 2D layout
        patch_pos = self.v.pos_embed[:, 2:, :].detach()
        patch_pos = patch_pos.reshape(1, ref_f * ref_t, dim).transpose(1, 2).reshape(1, dim, ref_f, ref_t)
        patch_pos = self._resize_pos_embed(patch_pos, f_dim, t_dim)
        # flatten back to a token sequence and re-attach the two special-token embeddings
        patch_pos = patch_pos.reshape(1, dim, f_dim * t_dim).transpose(1, 2)
        return nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), patch_pos], dim=1))

    def _init_from_deit(self, label_dim, fstride, tstride, input_fdim, input_tdim, imagenet_pretrain, model_size, verbose):
        """Build the backbone from a timm DeiT model (optionally ImageNet pretrained)."""
        backbone_name = self._resolve_timm_name(self._DEIT_MODELS[model_size])
        self.v = timm.create_model(backbone_name, pretrained=imagenet_pretrain)

        # Swap in the relaxed PatchEmbed, sized from the backbone itself so that the
        # tiny/small variants (whose width is not 768) load their weights correctly.
        embed_dim = self.v.pos_embed.shape[2]
        relaxed_embed = PatchEmbed(img_size=self.v.patch_embed.img_size, patch_size=self.v.patch_embed.patch_size,
                                   in_chans=3, embed_dim=embed_dim)
        relaxed_embed.load_state_dict(self.v.patch_embed.state_dict())
        self.v.patch_embed = relaxed_embed

        self.original_num_patches = self.v.patch_embed.num_patches
        # (attribute name keeps its historical spelling so existing users are not broken)
        self.oringal_hw = int(self.original_num_patches ** 0.5)
        self.original_embedding_dim = embed_dim
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

        # automatically get the intermediate shape
        f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
        num_patches = f_dim * t_dim
        self.v.patch_embed.num_patches = num_patches
        if verbose:
            print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
            print('number of patches={:d}'.format(num_patches))

        # the linear projection layer: single input channel, possibly overlapping patches
        new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        if imagenet_pretrain:
            # collapse the pretrained input-channel kernels into one channel by summing them
            new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
            new_proj.bias = self.v.patch_embed.proj.bias
        self.v.patch_embed.proj = new_proj

        # the positional embedding
        if imagenet_pretrain:
            # reuse the pretrained embedding, adapted from its square grid to (f_dim, t_dim)
            self.v.pos_embed = self._build_pos_embed(self.oringal_hw, self.oringal_hw, f_dim, t_dim)
        else:
            # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
            # TODO can use sinusoidal positional embedding instead
            self.v.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.original_embedding_dim))
            trunc_normal_(self.v.pos_embed, std=.02)

    def _init_from_audioset(self, label_dim, fstride, tstride, input_fdim, input_tdim, verbose, pretrained_model):
        """Build the backbone from a checkpoint pretrained on both ImageNet and AudioSet."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
        sd = torch.load(pretrained_model, map_location=device)
        audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
        # presumably the checkpoint was saved from a DataParallel model, hence the wrapping - TODO confirm
        audio_model = torch.nn.DataParallel(audio_model)
        audio_model.load_state_dict(sd, strict=False)
        self.v = audio_model.module.v
        self.original_embedding_dim = self.v.pos_embed.shape[2]
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

        f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
        num_patches = f_dim * t_dim
        self.v.patch_embed.num_patches = num_patches
        if verbose:
            print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
            print('number of patches={:d}'.format(num_patches))

        # NOTE(review): the patch projection keeps the checkpoint's stride of 10 here, so
        # fstride/tstride other than 10 look inconsistent with the embedding built below - confirm.
        # If the input is shorter than the AudioSet one the positional embedding is cut
        # (from the middle); if it is longer the embedding is interpolated.
        ref_f, ref_t = self._AUDIOSET_GRID
        self.v.pos_embed = self._build_pos_embed(ref_f, ref_t, f_dim, t_dim)

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        """Return ``(f_dim, t_dim)``, the patch grid produced by a 16x16 projection with the given strides.

        The shape is obtained by running a throw-away convolution over a dummy input of
        size ``input_fdim`` x ``input_tdim``.
        """
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        return f_dim, t_dim

    @autocast()
    def forward(self, x, external_features=None):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :param external_features: accepted for interface compatibility; not used here
        :return: the normalized token embeddings (cls token, distillation token, then the patches);
                 when ``return_hidden_state`` is truthy, a tuple ``(embeddings, all_hidden_states)`` where
                 ``all_hidden_states`` holds the normalized output of every transformer block
        """
        all_hidden_states = () if self.return_hidden_state else None
        # (batch, time, freq) -> (batch, 1, freq, time)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        batch_size = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(batch_size, -1, -1)
        dist_token = self.v.dist_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
            if self.return_hidden_state:
                all_hidden_states = all_hidden_states + (self.v.norm(x),)
        x = self.v.norm(x)

        # NOTE: no pooling of the cls/distillation tokens and no mlp_head here;
        # the full token sequence is returned.
        if self.return_hidden_state:
            return x, all_hidden_states
        return x
if __name__ == '__main__':
    # Smoke test. Usage: python ast_models.py [path/to/audioset_checkpoint.pth]
    import sys

    input_tdim = 100
    ast_mdl = ASTModel(input_tdim=input_tdim)
    # input a batch of 10 spectrograms, each with 100 time frames and 128 frequency bins
    test_input = torch.rand([10, input_tdim, 128])
    test_output = ast_mdl(test_input)
    # forward() returns token embeddings (mlp_head is not applied), so the output is
    # [10, 110, 768]: 12 x 9 patches plus the cls and distillation tokens.
    print(test_output.shape)

    # The AudioSet-pretrained model needs a checkpoint, taken from the first command-line argument.
    if len(sys.argv) > 1:
        input_tdim = 256
        ast_mdl = ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True,
                           pretrained_model=sys.argv[1])
        # input a batch of 10 spectrograms, each with 256 time frames and 128 frequency bins
        test_input = torch.rand([10, input_tdim, 128])
        test_output = ast_mdl(test_input)
        # output should be in shape [10, 302, 768]: 12 x 25 patches plus the two special tokens.
        print(test_output.shape)
    else:
        print('No AudioSet checkpoint path given; skipping the audioset_pretrain example.')