|
from typing import Union, Callable, List, Optional, Dict |
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim import Adam |
|
import numpy as np |
|
import librosa |
|
import miniaudio |
|
from pathlib import Path |
|
from sklearn.model_selection import train_test_split |
|
from tqdm import tqdm |
|
from functools import partial |
|
import math |
|
|
|
from mae import MaskedAutoencoderViT |
|
|
|
|
|
def load_audio(
    path: str,
    sr: int = 32000,
    duration: int = 20,
) -> np.ndarray:
    """Decode the start of an audio file as a mono float32 signal.

    The file is decoded (and resampled to ``sr``) by miniaudio, and the
    first chunk of at most ``sr * duration`` frames is returned. Presumably
    the chunk is shorter when the file itself is shorter.

    Args:
        path: audio file to decode.
        sr: target sample rate in Hz.
        duration: maximum number of seconds to return.

    Returns:
        1-D array of samples.

    Raises:
        ValueError: if the decoder yields no frames at all.
        Decoder errors raised by ``miniaudio.stream_file`` propagate unchanged.
    """
    stream = miniaudio.stream_file(
        path,
        output_format=miniaudio.SampleFormat.FLOAT32,
        nchannels=1,
        sample_rate=sr,
        frames_to_read=sr * duration,
    )
    try:
        chunk = next(stream)
    except StopIteration:
        # A bare StopIteration has an empty message; give callers something useful.
        raise ValueError(f"no audio frames decoded from {path}") from None
    finally:
        # Release the decoder deterministically instead of waiting for GC.
        stream.close()
    return np.array(chunk)
|
|
|
|
|
def mel_spectrogram(
    signal: np.ndarray,
    sr: int = 32000,
    n_fft: int = 800,
    hop_length: int = 320,
    n_mels: int = 128,
) -> np.ndarray:
    """Compute a log-power (dB) mel spectrogram with time on the first axis.

    Args:
        signal: audio samples.
        sr: sample rate of ``signal`` in Hz.
        n_fft: FFT window length in samples (Hann window).
        hop_length: hop between frames in samples.
        n_mels: number of mel bands.

    Returns:
        The transpose of librosa's mel spectrogram after ``power_to_db``,
        i.e. one row per frame and one column per mel band.
    """
    power = librosa.feature.melspectrogram(
        y=signal,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        window='hann',
        pad_mode='constant',
    )
    decibels = librosa.power_to_db(power)
    # librosa puts mel bands first; downstream code indexes frames on axis 0.
    return np.transpose(decibels)
|
|
|
|
|
def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize ``arr`` to zero mean and unit standard deviation.

    Statistics are taken over all elements. ``eps`` is added to the standard
    deviation so a constant input yields zeros rather than dividing by zero.
    """
    centered = arr - arr.mean()
    scale = arr.std() + eps
    return centered / scale
|
|
|
|
|
# ---- Training configuration -------------------------------------------------
# Fall back to CPU so the script still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
seed = 42
train_size = 0.8  # fraction of each genre's files used for training
batch_size_train = 10
batch_size_test = 32
num_workers = 1
lr = 1e-3
epochs = 200
detection_epoch = 20  # early-stopping patience: epochs without a new best test accuracy

# ---- Audio front-end ----------------------------------------------------------
# NOTE(review): presumably these must match what the MAE checkpoint was trained with.
sr = 32000
n_fft = 800
hop_length = 320
# Seconds passed to load_audio, which requests sr * duration frames in one chunk,
# so a large value effectively decodes whole clips.
# NOTE(review): confirm the decoder does not pre-allocate a buffer of that size.
duration = 10000

# ---- MAE input ------------------------------------------------------------------
feature_length = 2048  # mel frames per snippet passed to the MAE encoder
patch_size = 16  # the time axis is padded up to a multiple of this

feature_padding = True  # NOTE(review): not read anywhere in the visible code
header = 'mean'  # forwarded as `header=` to mae.forward_encoder_no_mask

# ---- Classifier head --------------------------------------------------------------
# 768 input features; 10 outputs, presumably one per GTZAN genre.
mlp_num_neurons = [768, 10]
mlp_activation_layer = nn.ReLU
mlp_bias = True

# ---- Reproducibility --------------------------------------------------------------
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
|
|
|
|
|
# Pre-trained music MAE, used frozen as a feature extractor.
# The input size and patch size come from the same constants the inference
# code uses (snippet length `feature_length`, padding multiple `patch_size`),
# so the two cannot drift apart. 128 matches mel_spectrogram's default n_mels.
mae = MaskedAutoencoderViT(
    img_size=(feature_length, 128),  # (time frames, mel bins), matching the snippet layout
    patch_size=patch_size,
    in_chans=1,
    embed_dim=768,
    depth=12,
    num_heads=12,
    decoder_mode=1,
    no_shift=False,
    decoder_embed_dim=512,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    norm_pix_loss=False,
    pos_trainable=False,
)

# NOTE(review): the doubled ".pth.pth" suffix looks accidental; confirm the on-disk name.
ckpt_path = 'music-mae-32kHz.pth.pth'
# torch.load unpickles the file, so only load checkpoints from a trusted source.
mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
mae.to(device)
mae.eval()  # inference only; the MAE's parameters are never given to the optimizer
|
|
|
|
|
# Index audio files by genre. The genre is the file-name prefix before the
# first '.', e.g. a file named "blues.00000.wav" is filed under "blues".
fp = Path('GTZAN-dataset/genres_original')
audio_data = dict()

# Sort the listings: iterdir() order is filesystem-dependent, and the seeded
# train_test_split below is only reproducible if its input order is fixed.
for d in sorted(fp.iterdir()):
    if not d.is_dir():
        continue
    for f in sorted(d.iterdir()):
        # Hidden files (e.g. ".DS_Store") would otherwise produce an empty genre label.
        if not f.is_file() or f.name.startswith('.'):
            continue
        genre = f.name.split('.')[0]
        audio_data.setdefault(genre, []).append(str(f))

# Split each genre separately so the train/test ratio holds per genre.
audio_data_train = dict()
audio_data_test = dict()

for k, v in audio_data.items():
    train_data, test_data = train_test_split(v, train_size=train_size, random_state=seed, shuffle=True)
    audio_data_train[k] = train_data
    audio_data_test[k] = test_data
|
|
|
|
|
def _pad_to_patch_multiple(mel_spec: np.ndarray) -> np.ndarray:
    """Pad axis 0 (time) up to a multiple of ``patch_size`` using the spectrogram's minimum value."""
    input_length = mel_spec.shape[0]
    pad_length = patch_size * math.ceil(input_length / patch_size) - input_length
    if pad_length > 0:
        mel_spec = np.pad(mel_spec, ((0, pad_length), (0, 0)), mode='constant', constant_values=mel_spec.min())
    return mel_spec


@torch.no_grad()
def _embed_mel(mel_spec: np.ndarray) -> np.ndarray:
    """Embed consecutive ``feature_length``-frame snippets with the MAE and average them.

    Each snippet is standardized on its own. Each snippet embedding is
    L2-normalized before averaging; the mean itself is not renormalized.
    The last snippet may be shorter than ``feature_length``.
    """
    embeds = []
    for start in range(0, mel_spec.shape[0], feature_length):
        snippet = normalize(mel_spec[start:start + feature_length])
        # Add batch and channel axes: (1, 1, frames, mel bins).
        x = torch.from_numpy(snippet[None, None, :, :]).to(device)
        y = mae.forward_encoder_no_mask(x, header=header)
        y = y / y.norm(p=2, dim=-1, keepdim=True)
        embeds.append(y.cpu().numpy().squeeze())
    return np.mean(embeds, axis=0)


@torch.no_grad()
def infer_mae_embedding(data: Dict) -> Dict:
    """Compute one MAE embedding per audio file.

    Args:
        data: mapping from label to a list of audio file paths.

    Returns:
        Mapping from label to a list of per-file embeddings, each the mean of
        the file's L2-normalized snippet embeddings. Files that fail to load
        are printed and skipped, so a label may have fewer entries than input
        files, and a label whose files all fail is absent.
    """
    emb_data = dict()

    for k, v in tqdm(data.items(), desc='infer mae embedding', total=len(data)):
        for f in v:
            try:
                # Pass sr explicitly so decoding and the mel front-end agree.
                mel_spec = mel_spectrogram(load_audio(f, sr=sr, duration=duration), sr=sr, n_fft=n_fft, hop_length=hop_length)
            except Exception as e:
                # Best-effort: report unreadable files and keep going.
                print(e)
                print(f)
                continue

            if mel_spec.shape[0] == 0:
                # Nothing to embed; np.mean over no snippets would give NaN.
                print(f"empty spectrogram, skipping: {f}")
                continue

            y = _embed_mel(_pad_to_patch_multiple(mel_spec))
            emb_data.setdefault(k, []).append(y)

    return emb_data
|
|
|
|
|
audio_emb_train = infer_mae_embedding(audio_data_train)
audio_emb_test = infer_mae_embedding(audio_data_test)

# A set iterates in hash order, which varies between interpreter runs (string
# hash randomization), so sort it for a reproducible label -> index mapping.
# Including the test genres avoids a KeyError when building the test dataset
# if a genre's training files all failed to load.
label_set = set(audio_emb_train) | set(audio_emb_test)
label_map = {label: i for i, label in enumerate(sorted(label_set))}
print(label_map)
|
|
|
|
|
class MLP(torch.nn.Sequential):
    """Stack of Linear layers with activation and dropout between them.

    For ``num_neurons = [d0, d1, ..., dn]`` the module is
    ``Linear(d0, d1), act, Dropout, ..., Linear(d_{n-1}, dn)``. The final
    Linear has no activation or dropout after it, so its output is unbounded
    (e.g. logits for ``nn.CrossEntropyLoss``).

    Args:
        num_neurons: layer widths, input width first; at least two entries.
        activation_layer: zero-argument factory for the hidden activation,
            or ``None`` to use no activation.
        bias: whether each Linear layer has a bias term.
        dropout: dropout probability applied after each hidden layer.

    Raises:
        ValueError: if fewer than two widths are given.
    """

    def __init__(
        self,
        num_neurons: List[int],
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        if len(num_neurons) < 2:
            raise ValueError(f"num_neurons needs at least 2 widths, got {list(num_neurons)!r}")

        layers = []
        last = len(num_neurons) - 2  # index of the output Linear layer
        for i, (c_in, c_out) in enumerate(zip(num_neurons[:-1], num_neurons[1:])):
            layers.append(torch.nn.Linear(c_in, c_out, bias=bias))
            if i == last:
                break  # no activation/dropout after the output layer
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout))

        super().__init__(*layers)
|
|
|
|
|
class SimpleDataset(Dataset):
    """In-memory dataset of ``(embedding, label_index)`` pairs.

    Args:
        dict_data: mapping from label to a list of embeddings.
        label_map: mapping from label to integer class index. Every key of
            ``dict_data`` must be present, or construction raises KeyError.
    """

    def __init__(self, dict_data: Dict, label_map: Dict):
        # Flatten label groups in dict order, keeping each group's order.
        self.embed_with_label = [
            (emb, label_map[label])
            for label, embeddings in dict_data.items()
            for emb in embeddings
        ]

    def __len__(self):
        return len(self.embed_with_label)

    def __getitem__(self, idx):
        pair = self.embed_with_label[idx]
        return pair
|
|
|
|
|
# Wrap the per-genre embedding dicts as (embedding, label index) datasets.
train_dataset = SimpleDataset(dict_data=audio_emb_train, label_map=label_map)
test_dataset = SimpleDataset(dict_data=audio_emb_test, label_map=label_map)
for dataset_name, dataset in (("train_dataset", train_dataset), ("test_dataset", test_dataset)):
    print(f"len({dataset_name}): {len(dataset)}")
|
|
|
|
|
def train_one_epoch(model, device, dataloader, loss_fn, optimizer):
    """Run one optimization pass over ``dataloader``.

    Switches ``model`` to training mode. Then, for every ``(inputs, targets)``
    batch, it moves both to ``device``, computes
    ``loss_fn(model(inputs), targets)`` and takes one ``optimizer`` step.
    Returns nothing.
    """
    model.train()
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
|
|
|
|
|
@torch.no_grad()
def eval_one_epoch(model, device, dataloader, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Each batch's loss is weighted by its size. This assumes ``loss_fn``
    returns a per-batch mean, as ``nn.CrossEntropyLoss`` does by default, so
    the result is the mean loss per sample.

    Returns:
        ``(mean_loss, accuracy)``, where accuracy is the fraction of samples
        whose argmax over the last axis equals the target. If the dataloader
        yields no samples, ``(nan, nan)`` is returned instead of raising
        ZeroDivisionError.
    """
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_num = 0

    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        y_logit = model(x)
        batch_size = x.shape[0]

        total_loss += loss_fn(y_logit, y).item() * batch_size
        total_correct += (y_logit.argmax(dim=-1) == y).sum().item()
        total_num += batch_size

    if total_num == 0:
        return float('nan'), float('nan')

    return total_loss / total_num, total_correct / total_num
|
|
|
|
|
# Classifier head trained on the frozen MAE embeddings
# (only mlp's parameters are handed to the optimizer).
mlp = MLP(
    num_neurons=mlp_num_neurons,
    activation_layer=mlp_activation_layer,
    bias=mlp_bias,
    dropout=0.0,
)
print(mlp)  # print the instance so the actual layer stack is shown

mlp.to(device)

optimizer = Adam(mlp.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

# Only the training loader shuffles; evaluation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=num_workers)
|
|
|
# Baseline before any training.
test_loss, test_accuracy = eval_one_epoch(mlp, device, test_dataloader, loss_fn)
print(f"init: test loss {test_loss:.4f}, test accuracy {test_accuracy:.4f}")

# Early stopping: stop once `detection_epoch` epochs pass without a new best
# test accuracy.
# NOTE(review): the best epoch is selected on the test split, so best_accuracy
# is an optimistic estimate; a separate validation split would avoid that.
best_accuracy = 0.0
at = 0  # epoch at which best_accuracy was reached

for epoch in range(epochs):
    train_one_epoch(mlp, device, train_dataloader, loss_fn, optimizer)
    test_loss, test_accuracy = eval_one_epoch(mlp, device, test_dataloader, loss_fn)

    print(f"epoch {epoch}: test loss {test_loss:.4f}, test accuracy {test_accuracy:.4f}")

    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        at = epoch

    if epoch - at >= detection_epoch:
        print(f"early stop at epoch {epoch}")
        print(f"best accuracy: {best_accuracy:.4f} at epoch {at}")
        break
else:
    # All epochs ran without early stopping: still report the best result.
    print(f"best accuracy: {best_accuracy:.4f} at epoch {at}")
|
|