Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Reference Repo: https://github.com/facebookresearch/AudioMAE | |
""" | |
import torch | |
import torch.nn as nn | |
from timm.models.layers import to_2tuple | |
import qa_mdt.audioldm_train.modules.audiomae.models_vit as models_vit | |
import qa_mdt.audioldm_train.modules.audiomae.models_mae as models_mae | |
# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128)) | |
class PatchEmbed_new(nn.Module):
    """Flexible Image to Patch Embedding

    Splits a 4-D input into (possibly overlapping) patches with a single
    ``nn.Conv2d`` (kernel ``patch_size``, stride ``stride``, no padding) and
    returns one ``embed_dim`` vector per patch.

    Args:
        img_size: Reference input size (int or 2-tuple); used to precompute
            ``patch_hw`` and ``num_patches``.
        patch_size: Convolution kernel size (int or 2-tuple).
        in_chans: Number of input channels.
        embed_dim: Embedding dimension produced for each patch.
        stride: Convolution stride (int or 2-tuple); a stride smaller than
            ``patch_size`` yields overlapping patches.

    Attributes:
        patch_hw: ``(h, w)`` patch grid produced for an ``img_size`` input.
        num_patches: ``h * w`` for an ``img_size`` input.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        self.img_size = img_size
        self.patch_size = patch_size
        # stride may be smaller than patch_size, giving overlapped patches.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )
        _, _, h, w = self.get_output_shape(img_size)  # (1, embed_dim, h, w)
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape ``self.proj`` produces for a single input.

        A zero-filled probe of shape ``(1, in_channels, img_size[0], img_size[1])``
        is passed through the projection without tracking gradients.

        Args:
            img_size: ``(H, W)`` of the probe input.

        Returns:
            ``torch.Size`` ``(1, embed_dim, h, w)``.
        """
        # Match the conv's own channel count, dtype and device. A hard-coded
        # single channel made construction fail for any in_chans != 1
        # (including the default in_chans=3). Zeros, unlike randn, also leave
        # the global RNG state untouched.
        weight = self.proj.weight
        probe = torch.zeros(
            1,
            self.proj.in_channels,
            img_size[0],
            img_size[1],
            dtype=weight.dtype,
            device=weight.device,
        )
        with torch.no_grad():
            return self.proj(probe).shape

    def forward(self, x):
        """Embed ``x`` patch-wise.

        Args:
            x: ``(B, in_chans, H, W)`` tensor. ``H``/``W`` are not required to
                equal ``img_size``; ``num_patches`` only describes ``img_size``
                inputs.

        Returns:
            ``(B, h * w, embed_dim)`` tensor, where ``(h, w)`` is the patch grid
            for this particular ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) input, got shape {tuple(x.shape)}"
            )
        x = self.proj(x)  # (B, embed_dim, h, w)
        return x.flatten(2).transpose(1, 2)  # (B, h * w, embed_dim)
class AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)

    Builds ``models_vit.vit_base_patch16`` (527 classes, global pooling, 2-D
    masking). It then swaps in a single-channel 16x16 ``PatchEmbed_new`` sized
    for a 1024 x 128 input and loads weights from ``checkpoint_path``.

    Args:
        checkpoint_path: File loadable by ``torch.load`` whose ``"model"``
            entry is a state dict. Defaults to the original hard-coded path.
            Pass ``None`` to skip loading and keep the freshly built weights.

    Attributes:
        model: The wrapped ViT.
        load_msg: Result of ``load_state_dict`` (missing/unexpected keys), or
            ``None`` when no checkpoint was loaded.
    """

    def __init__(
        self,
        checkpoint_path="/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth",
    ):
        super().__init__()
        model = models_vit.__dict__["vit_base_patch16"](
            num_classes=527,
            drop_path_rate=0.1,
            global_pool=True,
            mask_2d=True,
            use_custom_patch=False,
        )
        img_size = (1024, 128)
        emb_dim = 768
        model.patch_embed = PatchEmbed_new(
            img_size=img_size,
            patch_size=(16, 16),
            in_chans=1,
            embed_dim=emb_dim,
            stride=16,
        )
        # 1024 // 16 = 64 by 128 // 16 = 8 -> 512 patches for these sizes.
        num_patches = model.patch_embed.num_patches
        # Frozen zero placeholder; the "+ 1" is presumably the class-token slot
        # (depends on models_vit, confirm). Real values are expected from the
        # checkpoint. Because loading uses strict=False, it stays zero if the
        # key is absent.
        model.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
        )

        self.load_msg = None
        if checkpoint_path is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            # strict=False tolerates missing/unexpected keys. Keep the report
            # rather than discarding it, so callers can check what loaded.
            self.load_msg = model.load_state_dict(checkpoint["model"], strict=False)
        self.model = model

    def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
        """Run the wrapped ViT.

        Args:
            x: mel fbank [Batch, 1, T, F]
            mask_t_prob: T masking ratio (percentage of removed patches).
            mask_f_prob: F masking ratio (percentage of removed patches).

        Returns:
            Whatever ``self.model`` returns for these arguments.
        """
        return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
class Vanilla_AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM)

    Builds ``models_mae.mae_vit_base_patch16`` (1 input channel, 1024 x 128
    input) and loads weights from ``checkpoint_path``. The wrapped model is put
    in ``eval()`` mode. Note that a later ``.train()`` call on this wrapper
    propagates to it.

    Args:
        checkpoint_path: File loadable by ``torch.load`` whose ``"model"``
            entry is a state dict. Defaults to the original hard-coded path.
            Pass ``None`` to skip loading.

    Attributes:
        model: The wrapped MAE.
        load_msg: Result of ``load_state_dict`` (missing/unexpected keys), or
            ``None`` when no checkpoint was loaded.
    """

    def __init__(self, checkpoint_path="data/checkpoints/audiomae_16k_128bins.ckpt"):
        super().__init__()
        model = models_mae.__dict__["mae_vit_base_patch16"](
            in_chans=1, audio_exp=True, img_size=(1024, 128)
        )
        self.load_msg = None
        if checkpoint_path is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            # strict=False: the checkpoint is allowed to lack decoder modules
            # (not required here). The report is kept for inspection.
            self.load_msg = model.load_state_dict(checkpoint["model"], strict=False)
        self.model = model.eval()

    def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
        """Encode ``x`` with the MAE encoder, without masking and without gradients.

        Only ``no_mask=True`` with ``no_average=False`` is supported. Every
        other combination is deprecated and raises.

        Args:
            x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
            mask_ratio: Ignored; kept for backward compatibility (the masked
                path that used it is deprecated).
            no_mask: Must be truthy.
            no_average: Must be falsy.

        Returns:
            Output of ``self.model.forward_encoder_no_mask(x)`` (per the original
            note, ``[B, 513, 768]``).

        Raises:
            RuntimeError: For any deprecated argument combination.
        """
        if not no_mask or no_average:
            raise RuntimeError("This function is deprecated")
        with torch.no_grad():
            return self.model.forward_encoder_no_mask(x)
if __name__ == "__main__":
    # Smoke test: embed a random batch of four 1024 x 128 single-channel inputs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Vanilla_AudioMAE().to(device)
    mel = torch.randn(4, 1, 1024, 128, device=device)
    print("The first run")
    embed = model(mel, mask_ratio=0.0, no_mask=True)
    print(embed)
    print("The second run")
    # The masked path (no_mask=False) is deprecated and always raises
    # RuntimeError; report that instead of aborting the script.
    try:
        embed = model(mel, mask_ratio=0.0)
    except RuntimeError as err:
        print(f"Masked encoding unavailable: {err}")
    else:
        print(embed)