import json
import os
from os import path
from typing import Iterator

import librosa
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from joblib import dump, load
from sklearn.base import BaseEstimator, ClassifierMixin
from torch import nn
from torch.utils.data import random_split
from tqdm import tqdm

from preprocessing.dataset import get_music4dance_examples
DANCE_INFO_FILE = "data/dance_info.csv"
dance_info_df = pd.read_csv(
    DANCE_INFO_FILE,
    converters={"tempoRange": lambda s: json.loads(s.replace("'", '"'))},
)
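# Expected shape of data/dance_info.csv (illustrative sketch; the real file may
# carry more columns, and the dance ids and tempo numbers here are made up):
#
#   id,tempoRange
#   waltz,"{'min': 84, 'max': 90}"
#   salsa,"{'min': 150, 'max': 250}"
#
# The converter above parses the single-quoted tempoRange string into a dict
# with "min"/"max" keys, which get_valid_dances_from_bpm compares against a
# song's bpm.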

class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
    """
    Trains one binary classifier per dance; each classifier is only trained and
    queried on songs whose BPM falls inside that dance's tempo range.

    Features:
    - Spectrogram
    - BPM
    """

    def __init__(self, device="cpu", lr=1e-4, verbose=True) -> None:
        self.device = device
        self.verbose = verbose
        self.lr = lr
        # One binary classifier and optimizer per dance id, created lazily in fit().
        self.classifiers = {}
        self.optimizers = {}
        self.criterion = nn.BCELoss()

    def get_valid_dances_from_bpm(self, bpm: float) -> list[str]:
        mask = dance_info_df["tempoRange"].apply(
            lambda interval: interval["min"] <= bpm <= interval["max"]
        )
        return list(dance_info_df["id"][mask])
    def fit(self, x, y):
        """
        x: iterable of (spec, bpm) pairs, one per spectrogram window.
           spec shape should be (channels, freq_bins, time_frames).
        y: iterable of string dance labels, one per window.
        """
        epoch_loss = 0
        pred_count = 0
        data_loader = zip(x, y)
        if self.verbose:
            data_loader = tqdm(data_loader, total=len(y))
        for (spec, bpm), label in data_loader:
            # Only train the models whose dance is plausible at this bpm.
            matching_dances = self.get_valid_dances_from_bpm(bpm)
            spec = torch.from_numpy(spec).to(self.device)
            for dance in matching_dances:
                if dance not in self.classifiers or dance not in self.optimizers:
                    classifier = DanceCNN().to(self.device)
                    self.classifiers[dance] = classifier
                    self.optimizers[dance] = torch.optim.Adam(
                        classifier.parameters(), lr=self.lr
                    )
            models = [
                (dance, model, self.optimizers[dance])
                for dance, model in self.classifiers.items()
                if dance in matching_dances
            ]
            for model_i, (dance, model, opt) in enumerate(models, start=1):
                opt.zero_grad()
                output = model(spec)
                target = torch.tensor([float(dance == label)], device=self.device)
                loss = self.criterion(output, target)
                epoch_loss += loss.item()
                pred_count += 1
                loss.backward()
                if self.verbose:
                    data_loader.set_description(
                        f"model: {model_i}/{len(models)}, loss: {loss.item():.4f}"
                    )
                opt.step()
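    # Sketch of the inputs fit() expects (values are hypothetical; in practice
    # they come from features_from_path and the music4dance labels):
    #
    #   x = features_from_path(["a.wav", "b.wav"])  # yields (spec window, bpm) pairs
    #   y = ["waltz"] * 5 + ["salsa"] * 5           # one string label per window
    #   DanceTreeClassifier().fit(x, y)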
    def predict(self, x) -> list[str]:
        # x: iterable of (spec, bpm) pairs, in the same format fit() consumes.
        results = []
        with torch.no_grad():
            for spec, bpm in x:
                spec = torch.from_numpy(spec).to(self.device)
                # Only score dances that match the bpm and have a trained classifier.
                matching_dances = [
                    d for d in self.get_valid_dances_from_bpm(bpm) if d in self.classifiers
                ]
                scores = torch.cat(
                    [self.classifiers[dance](spec) for dance in matching_dances]
                )
                results.append(matching_dances[scores.argmax().item()])
        return results
    def save(self, folder: str):
        # Create the output folder
        classifier_path = path.join(folder, "classifier")
        os.makedirs(classifier_path, exist_ok=True)
        # Detach the torch objects so the sklearn wrapper pickles cleanly
        classifiers = self.classifiers
        optimizers = self.optimizers
        criterion = self.criterion
        self.classifiers = None
        self.optimizers = None
        self.criterion = None
        # Save the PyTorch weights, one .pth file per dance
        for dance, classifier in classifiers.items():
            torch.save(
                classifier.state_dict(), path.join(classifier_path, dance + ".pth")
            )
        # Save the sklearn wrapper itself
        dump(self, path.join(folder, "sklearn.joblib"))
        # Restore the torch objects
        self.classifiers = classifiers
        self.optimizers = optimizers
        self.criterion = criterion
    @staticmethod
    def from_config(folder: str, device="cpu") -> "DanceTreeClassifier":
        # Load the per-dance weights
        model_paths = (
            p for p in os.listdir(path.join(folder, "classifier")) if p.endswith(".pth")
        )
        classifiers = {}
        for model_path in model_paths:
            dance = model_path.split(".")[0]
            model = DanceCNN().to(device)
            model.load_state_dict(
                torch.load(
                    path.join(folder, "classifier", model_path), map_location=device
                )
            )
            classifiers[dance] = model
        wrapper = load(path.join(folder, "sklearn.joblib"))
        wrapper.classifiers = classifiers
        return wrapper
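
# Persistence usage sketch (hedged; `features` and `labels` are placeholders in
# the format fit() documents, and the folder matches the default used below):
#
#   clf = DanceTreeClassifier(device="cpu")
#   clf.fit(features, labels)
#   clf.save("models/weights/decision_tree")  # writes classifier/*.pth + sklearn.joblib
#   clf = DanceTreeClassifier.from_config("models/weights/decision_tree")
#   predictions = clf.predict(features)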

class DanceCNN(nn.Module):
    def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        kernel_size = (3, 9)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(16, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
        )
        # Flattened CNN output size for a (1, 128, 6 * 16000) mel-spectrogram window.
        embedding_dimension = 16 * 6 * 8
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dimension, 200),
            nn.ReLU(),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.cnn(x)
        # Accept both unbatched (C, H, W) and batched (N, C, H, W) inputs.
        x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
        return self.classifier(x)
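
# Shape sketch (assuming librosa's default 128 mel bands and the 6 s / 16 kHz
# windows produced by features_from_path below, i.e. an input of (1, 128, 96000)):
# the four conv + pool stages reduce it to a (16, 6, 8) feature map, which
# flattens to the 768-dimensional embedding_dimension above, and the classifier
# head emits a single sigmoid probability.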

def features_from_path(
    paths: list[str], audio_window_duration=6, audio_duration=30, resample_freq=16000
) -> Iterator[tuple[np.array, float]]:
    """
    Loads each audio file, estimates its tempo, and yields one
    (mel-spectrogram window, bpm) pair per audio_window_duration-second window.
    """
    for audio_path in paths:
        waveform, sr = librosa.load(audio_path, mono=True, sr=resample_freq)
        num_frames = audio_window_duration * sr
        tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
        spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
        spec_normalized = (spec - spec.mean()) / spec.std()
        spec_padded = librosa.util.fix_length(
            spec_normalized, size=sr * audio_duration, axis=1
        )
        batched_spec = np.expand_dims(spec_padded, axis=0)
        for i in range(audio_duration // audio_window_duration):
            spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
            yield (spec_window, tempo)
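
# Example (hedged; "song.wav" is a placeholder path). With the defaults above,
# each 30 s file yields 30 // 6 = 5 windows:
#
#   for spec_window, bpm in features_from_path(["song.wav"]):
#       print(spec_window.shape, bpm)  # (1, 128, 96000), estimated tempo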

def train_decision_tree(config: dict):
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    EPOCHS = config["trainer"]["min_epochs"]
    song_data_path = config["data_module"]["song_data_path"]
    song_audio_path = config["data_module"]["song_audio_path"]
    pl.seed_everything(SEED, workers=True)
    df = pd.read_csv(song_data_path)
    x, y = get_music4dance_examples(
        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
    )
    # Convert y back to string classes
    y = np.array(TARGET_CLASSES)[y.argmax(-1)]
    train_split, test_split = random_split(
        np.arange(len(x)), [0.1, 0.9]
    )  # Temporary to test efficacy
    train_i, test_i = np.array(train_split.indices), np.array(test_split.indices)
    train_paths, train_y = x[train_i], y[train_i]
    model = DanceTreeClassifier(device=DEVICE)
    windows_per_song = 30 // 6  # features_from_path defaults: 30 s of audio in 6 s windows
    for epoch in tqdm(range(1, EPOCHS + 1)):
        # Shuffle the data
        i = np.arange(len(train_paths))
        np.random.shuffle(i)
        train_paths = train_paths[i]
        train_y = train_y[i]
        train_x = features_from_path(train_paths)
        # Repeat each song's label so every spectrogram window gets its own copy.
        model.fit(train_x, np.repeat(train_y, windows_per_song))
    # Evaluate the model on per-window predictions
    preds = np.array(model.predict(features_from_path(x[test_i])))
    accuracy = (preds == np.repeat(y[test_i], windows_per_song)).mean()
    print(f"{accuracy=}")
    model.save("models/weights/decision_tree")