# audiomae / model.py
# Author: dslee2601 — commit c137996 ("warning msg")
import warnings
from typing import Tuple

import torch
import torchaudio
import torchaudio.transforms as transforms
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig
from transformers import PreTrainedModel
# The Config class and the Model class should live in the same file; otherwise, loading the model after pushing it to the Hugging Face Hub appears to fail.
class AudioMAEConfig(PretrainedConfig):
    """Hugging Face configuration object for the AudioMAE encoder.

    Args:
        img_size: (temporal_length, n_freq_bins) of the input spectrogram "image".
        in_chans: number of input channels of the spectrogram.
        num_classes: forwarded to the ViT as its ``num_classes``.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "audiomae"

    def __init__(
        self,
        img_size: Tuple[int, int] = (1024, 128),
        in_chans: int = 1,
        num_classes: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # model-specific fields, stored after the generic config fields are set up
        self.img_size, self.in_chans, self.num_classes = img_size, in_chans, num_classes
class AudioMAEEncoder(VisionTransformer):
    """AudioMAE encoder: a timm ``VisionTransformer`` over Kaldi log-mel filterbanks.

    - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
    - AudioMAE accepts a mono-channel input (i.e., in_chans=1)
    """

    # Filterbank normalization statistics (written in the paper).
    MEAN = -4.2677393
    STD = 4.5689974
    # Sample rate (Hz) the model expects; loaded audio is resampled to it.
    SAMPLE_RATE = 16000
    # Filterbank frames per input; with a 10 ms frame shift this is roughly 10.24 s of audio.
    N_FRAMES = 1024
    N_MEL_BINS = 128
    # Audio longer than this (seconds) triggers a truncation warning.
    MAX_AUDIO_SECONDS = 10.0

    def load_wav_file(self, file_path: str) -> torch.Tensor:
        """Load an audio file as a mono waveform resampled to 16 kHz.

        To use this, `torchaudio` and `ffmpeg` must be installed:
        - `ffmpeg` version must be >=4.4 and <7.
        - `ffmpeg` installation by `conda install -c conda-forge ffmpeg==6.1.1`

        Args:
            file_path: path of the audio file to read.

        Returns:
            Waveform tensor of shape (1, n_samples) at ``SAMPLE_RATE``.
        """
        audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, length)

        audio_len = audio.shape[-1] / sample_rate  # duration in seconds
        if audio_len > self.MAX_AUDIO_SECONDS:
            # Nothing is cut here: the excess frames are dropped in `waveform_to_melspec`.
            warnings.warn(
                f"current audio length is: {audio_len}s. "
                f"AudioMAE only accepts audio length up to {self.MAX_AUDIO_SECONDS:g}s. "
                "The audio frames exceeding it will be clipped.",
                stacklevel=2,
            )

        # AudioMAE accepts a mono channel: downmix multi-channel audio by averaging.
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # AudioMAE accepts 16 kHz audio.
        if sample_rate != self.SAMPLE_RATE:
            resampler = transforms.Resample(orig_freq=sample_rate, new_freq=self.SAMPLE_RATE)
            audio = resampler(audio)
        return audio

    def waveform_to_melspec(self, waveform: torch.FloatTensor) -> torch.Tensor:
        """Convert a 16 kHz waveform into a normalized (1024, 128) log-mel filterbank.

        Args:
            waveform: tensor of shape (n_channels, n_samples) or (n_samples,).
                `kaldi.fbank` extracts a single channel (its `channel` argument is
                left at its default here), so pass mono audio, e.g. from
                `load_wav_file`.

        Returns:
            Tensor of shape (1024, 128) = (n_frames, n_freq_bins), zero-padded or
            trimmed in time, then normalized with MEAN/STD.
        """
        if waveform.dim() == 1:
            # kaldi.fbank indexes a leading channel dimension; add one for 1-D input.
            waveform = waveform.unsqueeze(0)

        # Kaldi-compatible filterbank features; the parameters follow the AudioMAE
        # paper (4.2 implementation details).
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=self.N_MEL_BINS,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=self.SAMPLE_RATE,
            window_type='hanning',
            dither=0.0,
        )  # (n_frames, n_mel_bins)

        # Force exactly N_FRAMES frames: trim the tail, or zero-pad at the end of time.
        n_missing = self.N_FRAMES - mel_spectrogram.shape[0]
        if n_missing < 0:
            mel_spectrogram = mel_spectrogram[:self.N_FRAMES, :]
        elif n_missing > 0:
            # F.pad lists pads starting from the LAST dim:
            # (freq_before, freq_after, time_before, time_after)
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, n_missing))

        # scale, as in the AudioMAE implementation
        # [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)  # (1024, 128)
        return mel_spectrogram

    @torch.no_grad()
    def encode(self, file_path: str, device=None) -> torch.Tensor:
        """Encode an audio file into a grid of patch embeddings.

        Puts the module in eval mode as a side effect.

        Args:
            file_path: path of the audio file.
            device: device on which the ViT runs; defaults to the device of this
                module's parameters.

        Returns:
            CPU tensor of shape (d, h', w'): embedding dim, frequency patches,
            temporal patches.
        """
        self.eval()
        if device is None:
            device = next(self.parameters()).device

        waveform = self.load_wav_file(file_path)
        melspec = self.waveform_to_melspec(waveform)  # (length, n_freq_bins) = (1024, 128)
        melspec = melspec[None, None, :, :]  # (1, 1, length, n_freq_bins)
        z = self.forward_features(melspec.to(device)).cpu()  # (b, n_prefix + n, d)
        # drop the prefix tokens ([CLS], plus any extra tokens newer timm versions add)
        n_prefix = getattr(self, 'num_prefix_tokens', 1)
        z = z[:, n_prefix:, :]  # (b, n, d)

        _, _, w, h = melspec.shape  # w: temporal dim; h: freq dim
        wprime = round(w / self.patch_embed.patch_size[0])  # width in the latent space
        hprime = round(h / self.patch_embed.patch_size[1])  # height in the latent space
        # Tokens are ordered time-major over the patch grid; giving both sizes lets
        # einops reject a token count that does not match the grid.
        z = rearrange(z, 'b (w h) d -> b d h w', w=wprime, h=hprime)  # (b, d, h', w')
        return z[0]  # (d, h', w')
class PretrainedAudioMAEEncoder(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``AudioMAEEncoder``.

    The encoder is built from an ``AudioMAEConfig`` and exposed as ``self.encoder``.
    """

    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        encoder_kwargs = dict(
            img_size=config.img_size,
            in_chans=config.in_chans,
            num_classes=config.num_classes,
        )
        self.encoder = AudioMAEEncoder(**encoder_kwargs)

    def forward(self, file_path: str):
        """Encode the audio file at `file_path` on this model's device.

        Returns the encoder's (d h' w') feature grid.
        """
        return self.encoder.encode(file_path, self.device)  # (d h' w')