from typing import Tuple

import torch
import torchaudio
import torchaudio.transforms as transforms
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig, PreTrainedModel


class AudioMAEConfig(PretrainedConfig):
    model_type = "audiomae"

    def __init__(
        self,
        img_size: Tuple[int, int] = (1024, 128),
        in_chans: int = 1,
        num_classes: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes


class AudioMAEEncoder(VisionTransformer):
    def __init__(self, *args, **kwargs):
        """
        - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
        - AudioMAE accepts mono-channel input (i.e., in_chans=1)
        """
        super().__init__(*args, **kwargs)
        # dataset-level mel-spectrogram statistics used for normalization
        self.MEAN = -4.2677393
        self.STD = 4.5689974

    def load_wav_file(self, file_path: str):
        """
        To use this, `torchaudio` and `ffmpeg` must be installed.
        - The `ffmpeg` version must be >=4.4 and <7.
        - `ffmpeg` can be installed with `conda install -c conda-forge ffmpeg==6.1.1`.
        """
        audio, sample_rate = torchaudio.load(file_path)

        audio_len = audio.shape[-1] / sample_rate
        if audio_len > 10.0:
            print('current audio length is:', audio_len)
            print('[WARNING] AudioMAE only accepts audio lengths up to 10s. Frames beyond 10s will be clipped.')

        # mix down multi-channel audio to mono
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # resample to the 16 kHz rate expected by AudioMAE
        if sample_rate != 16000:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = converter(audio)
        return audio

    def waveform_to_melspec(self, waveform: torch.FloatTensor):
        # Kaldi-compatible 128-bin log-mel filterbank features
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=128,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=16000,
            window_type='hanning',
            dither=0.0,
        )

        # clip or zero-pad along the time axis to exactly 1024 frames (~10s)
        expected_frames = 1024
        current_frames = mel_spectrogram.shape[0]
        if current_frames > expected_frames:
            mel_spectrogram = mel_spectrogram[:expected_frames, :]
        elif current_frames < expected_frames:
            padding = expected_frames - current_frames
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, padding))

        # normalize with the dataset statistics
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)
        return mel_spectrogram

    @torch.no_grad()
    def encode(self, file_path: str, device):
        self.eval()

        waveform = self.load_wav_file(file_path)
        melspec = self.waveform_to_melspec(waveform)  # (1024, 128)
        melspec = melspec[None, None, :, :]  # (b, c, time, freq) = (1, 1, 1024, 128)
        z = self.forward_features(melspec.to(device)).cpu()
        z = z[:, 1:, :]  # drop the [CLS] token

        # recover the 2-D patch grid from the flattened token sequence
        b, c, w, h = melspec.shape
        wprime = round(w / self.patch_embed.patch_size[0])
        hprime = round(h / self.patch_embed.patch_size[1])

        z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)  # (b, d, n_freq_patches, n_time_patches)

        z = z[0]  # drop the batch dimension
        return z


class PretrainedAudioMAEEncoder(PreTrainedModel):
    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)

    def forward(self, file_path: str):
        device = self.device
        return self.encoder.encode(file_path, device)

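
# Minimal usage sketch (an assumption, not part of the original module): instantiates the
# encoder from the default config and encodes a WAV file. "example.wav" is a hypothetical
# path, and the weights here are randomly initialized unless a checkpoint is loaded via
# `PretrainedAudioMAEEncoder.from_pretrained(...)`.
if __name__ == "__main__":
    config = AudioMAEConfig()  # img_size=(1024, 128), in_chans=1, num_classes=0
    model = PretrainedAudioMAEEncoder(config)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    z = model("example.wav")  # hypothetical path to an audio file of up to 10s
    # With timm's ViT defaults (patch_size=16, embed_dim=768), z has shape
    # (768, 8, 64) = (embed_dim, n_freq_patches, n_time_patches).
    print(z.shape)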