re-organization

Files changed:
- sanity_check_result_audiomae.png → .fig/sanity_check_result_audiomae.png (+0 -0)
- .gitignore (+1 -0)
- .sample_sound/baby_coughing.wav (+0 -0)
- README.md (+1 -1)
- config.py (+16 -16)
- model.py (+112 -112)
- save_audioMAE.ipynb (+0 -0)
sanity_check_result_audiomae.png → .fig/sanity_check_result_audiomae.png
RENAMED (file unchanged)
.gitignore
ADDED

```diff
@@ -0,0 +1 @@
+__pycache__/
```
.sample_sound/baby_coughing.wav
ADDED
Binary file (882 kB).
README.md
CHANGED

```diff
@@ -32,7 +32,7 @@ The input audio is 10s, containing baby coughing, hiccuping, and adult sneezing.
 The latent dimension size of $z$ is reduced to 8 using PCA for visualization.
 
 <p align="center">
-<img src="sanity_check_result_audiomae.png" alt="" width=100%>
+<img src=".fig/sanity_check_result_audiomae.png" alt="" width=100%>
 </p>
 
 The result shows that the presence of labeled sound is clearly captured in the 3rd principal component (PC).
```
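The README's sanity check reduces the latent $z$ to 8 dimensions with PCA. Below is a minimal sketch of that step; the use of scikit-learn, the frequency-axis averaging, and the random placeholder latent are illustrative assumptions, not taken from the repository.

```python
# Hypothetical sketch of the PCA step described in the README (assumes scikit-learn).
import torch
from sklearn.decomposition import PCA

# z: (d, h', w') latent from AudioMAEEncoder.encode(), e.g. d=768, h'=8, w'=64
z = torch.randn(768, 8, 64)  # placeholder latent for illustration

# Treat each temporal position as a sample: average over the frequency axis,
# then move time to the first axis -> (w', d)
feats = z.mean(dim=1).transpose(0, 1).numpy()  # (64, 768)

pca = PCA(n_components=8)            # reduce the latent dimension to 8, as in the README
feats_8d = pca.fit_transform(feats)  # (64, 8); column index 2 is the 3rd principal component
third_pc = feats_8d[:, 2]
```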
config.py
CHANGED

@@ -1,16 +1,16 @@
All 16 lines are removed and re-added; the rendered content of the removed and added versions is identical. New file content:

```python
from transformers import PretrainedConfig
from typing import Tuple


class AudioMAEConfig(PretrainedConfig):
    model_type = "audiomae"

    def __init__(self,
                 img_size:Tuple[int,int]=(1024,128),
                 in_chans:int=1,
                 num_classes:int=0,
                 **kwargs,):
        super().__init__(**kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes
```
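A quick usage sketch of AudioMAEConfig (not part of the commit): because it subclasses transformers.PretrainedConfig, it can be serialized to and reloaded from JSON; the directory name below is made up.

```python
# Hypothetical usage sketch of AudioMAEConfig (the directory "tmp_cfg" is made up).
from config import AudioMAEConfig

cfg = AudioMAEConfig()               # defaults: img_size=(1024, 128), in_chans=1, num_classes=0
print(cfg.model_type, cfg.img_size)  # "audiomae" (1024, 128)

cfg.save_pretrained("tmp_cfg")       # writes tmp_cfg/config.json via PretrainedConfig
cfg2 = AudioMAEConfig.from_pretrained("tmp_cfg")
assert cfg2.in_chans == 1
```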
model.py
CHANGED

@@ -1,112 +1,112 @@
All 112 lines are removed and re-added; the rendered content of the removed and added versions is identical. New file content:

```python
import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.compliance import kaldi

from einops import rearrange

from timm.models.vision_transformer import VisionTransformer
from transformers import PreTrainedModel

from config import AudioMAEConfig


class AudioMAEEncoder(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        """
        - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
        - AudioMAE accepts a mono channel (i.e., in_chans=1)
        """
        self.MEAN = -4.2677393  # as reported in the paper
        self.STD = 4.5689974  # as reported in the paper

    def load_wav_file(self, file_path:str):
        """
        to use this, `torchaudio` and `ffmpeg` must be installed
        - `ffmpeg` version must be >=4.4 and <7.
        - `ffmpeg` installation by `conda install -c conda-forge ffmpeg==6.1.1`
        """
        audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, length)

        # length clip
        audio_len = audio.shape[-1] / sample_rate
        if audio_len > 10.0:
            print('[WARNING] AudioMAE only accepts audio length up to 10s. The audio frames exceeding 10s will be clipped.')

        # Check if the audio has multiple channels
        if audio.shape[0] > 1:
            # Convert stereo audio to mono by taking the mean across channels
            # AudioMAE accepts a mono channel.
            audio = torch.mean(audio, dim=0, keepdim=True)

        # resample the audio to 16 kHz
        # AudioMAE accepts 16 kHz
        if sample_rate != 16000:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = converter(audio)
        return audio

    def waveform_to_melspec(self, waveform:torch.FloatTensor):
        # Compute the Mel spectrogram using Kaldi-compatible features
        # the parameters are chosen as described in the AudioMAE paper (4.2 implementation details)
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=128,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=16000,
            window_type='hanning',
            dither=0.0
        )

        # Ensure the output shape matches 1x1024x128 by padding or trimming the time dimension
        expected_frames = 1024  # as described in the paper
        current_frames = mel_spectrogram.shape[0]
        if current_frames > expected_frames:
            mel_spectrogram = mel_spectrogram[:expected_frames, :]
        elif current_frames < expected_frames:
            padding = expected_frames - current_frames
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram,
                                                      (0, 0,        # (left, right) for the last dim (freq)
                                                       0, padding)  # (left, right) for the first dim (time)
                                                      )

        # scale
        # as in the AudioMAE implementation [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)  # (length, n_freq_bins) = (1024, 128)
        return mel_spectrogram

    @torch.no_grad()
    def encode(self, file_path:str):
        self.eval()

        waveform = self.load_wav_file(file_path)
        melspec = self.waveform_to_melspec(waveform)  # (length, n_freq_bins) = (1024, 128)
        melspec = melspec[None, None, :, :]  # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
        z = self.forward_features(melspec)  # (b, 1+n, d); d=768
        z = z[:, 1:, :]  # (b n d); remove [CLS], the class token

        b, c, w, h = melspec.shape  # w: temporal dim; h: freq dim
        wprime = round(w / self.patch_embed.patch_size[0])  # width in the latent space
        hprime = round(h / self.patch_embed.patch_size[1])  # height in the latent space

        # reconstruct the temporal and freq dims
        z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)  # (b d h' w')

        # remove the batch dim
        z = z[0]  # (d h' w')
        return z  # (d h' w')


class PretrainedAudioMAEEncoder(PreTrainedModel):
    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)

    def forward(self, file_path:str):
        return self.encoder.encode(file_path)  # (d h' w')
```
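A hedged end-to-end sketch of how the classes above could be used with the sample file added in this commit (not part of the commit itself). It builds a randomly initialized encoder; loading the released AudioMAE weights is not shown here.

```python
# Hypothetical usage sketch of the classes defined in config.py and model.py.
# Note: the encoder below is randomly initialized; pretrained weights would
# need to be loaded into model.encoder separately.
from config import AudioMAEConfig
from model import PretrainedAudioMAEEncoder

cfg = AudioMAEConfig()
model = PretrainedAudioMAEEncoder(cfg)

z = model(".sample_sound/baby_coughing.wav")  # sample file added in this commit
print(z.shape)  # (d, h', w') = (768, 8, 64) with timm's default ViT-Base/16 settings
```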
save_audioMAE.ipynb
CHANGED
The diff for this file is too large to render; see the raw diff.
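The notebook's contents are not rendered in this diff. Purely as an assumed illustration of how a custom PreTrainedModel such as PretrainedAudioMAEEncoder is typically saved or shared with transformers (which may differ from what save_audioMAE.ipynb actually does):

```python
# Illustrative only: a common transformers pattern for saving/pushing a custom model.
# The actual contents of save_audioMAE.ipynb are not shown in this diff.
from config import AudioMAEConfig
from model import PretrainedAudioMAEEncoder

# Make the custom classes discoverable via the Auto* API when the code is shared.
AudioMAEConfig.register_for_auto_class()
PretrainedAudioMAEEncoder.register_for_auto_class("AutoModel")

model = PretrainedAudioMAEEncoder(AudioMAEConfig())
# ... load the pretrained AudioMAE weights into model.encoder here ...

model.save_pretrained("audiomae-encoder")              # local directory (name made up)
# model.push_to_hub("your-username/audiomae-encoder")  # hypothetical repo id
```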