waidhoferj committed
Commit • 3a0f0a5 • 1 Parent(s): b6800ef

added decision tree

Files changed:
- environment.yml +1 -0
- models/audio_spectrogram_transformer.py +48 -1
- models/config/train_local.yaml +47 -0
- models/decision_tree.py +124 -0
- train.py +26 -1
environment.yml
CHANGED
@@ -9,6 +9,7 @@ dependencies:
   - pytorch
   - torchaudio
   - torchvision
+  - librosa
   - numpy
   - pandas
   - seaborn
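The only new dependency is librosa, which the decision tree model added below uses for tempo estimation and MFCC features. A quick sanity check that the updated environment resolves; the sample path is a placeholder, not a file tracked in this repo:

import librosa

# Placeholder clip; any short audio file works for this check.
waveform, sr = librosa.load("data/samples/example.wav", mono=True, sr=16000)
tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
print("Estimated tempo (BPM):", tempo)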
models/audio_spectrogram_transformer.py
CHANGED
@@ -1,4 +1,4 @@
-from transformers import
+from transformers import ASTModel, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer
 import torch
 from torch import nn
 from sklearn.utils.class_weight import compute_class_weight
@@ -7,6 +7,53 @@ import numpy as np
 
 accuracy = evaluate.load("accuracy")
 
+
+class MultiModalAST(nn.Module):
+
+
+    def __init__(self, labels, sample_rate, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        id2label, label2id = get_id_label_mapping(labels)
+        model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
+        self.ast_feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
+
+        self.ast_model = ASTModel.from_pretrained(
+            model_checkpoint,
+            num_labels=len(label2id),
+            label2id=label2id,
+            id2label=id2label,
+            ignore_mismatched_sizes=True
+        )
+        self.sample_rate = sample_rate
+
+        self.bpm_model = nn.Sequential(
+            nn.Linear(len(labels), 100),
+            nn.Linear(100, 50)
+        )
+
+        out_dim = 50 # TODO: Calculate output dimension
+        self.classifier = nn.Sequential(
+            nn.Linear(out_dim, 100),
+            nn.Linear(100, len(labels))
+        )
+
+    def vectorize_bpm(self, waveform):
+        pass
+
+
+    def forward(self, audio):
+
+        bpm_vector = self.vectorize_bpm(audio)
+        bpm_out = self.bpm_model(bpm_vector)
+
+        spectrogram = self.ast_feature_extractor(audio)
+        ast_out = self.ast_model(spectrogram)
+
+        # Late fusion
+        z = torch.cat([ast_out, bpm_out]) # Which dimension?
+        return self.classifier(z)
+
+
 def compute_metrics(eval_pred):
     predictions = np.argmax(eval_pred.predictions, axis=1)
     return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
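MultiModalAST is still a work in progress: vectorize_bpm is a stub, the classifier's input size is a TODO, and the concatenation dimension is an open question. The following is a rough sketch (not part of this commit) of one way the late-fusion step could be completed, assuming batch-first tensors and the pooled AST output as the clip embedding; vectorize_bpm_batch and the BPM bin edges are hypothetical, and the number of bins would have to match bpm_model's input size:

import torch
import torch.nn.functional as F

def vectorize_bpm_batch(bpms: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: one-hot encode each tempo into coarse BPM bins.
    idx = torch.bucketize(bpms, bins)
    return F.one_hot(idx, num_classes=len(bins) + 1).float()

def late_fusion_forward(ast_model, bpm_model, classifier, spectrogram, bpms, bins):
    ast_out = ast_model(spectrogram).pooler_output        # (batch, hidden_size), assuming the pooled output is used
    bpm_out = bpm_model(vectorize_bpm_batch(bpms, bins))  # (batch, 50)
    z = torch.cat([ast_out, bpm_out], dim=-1)             # fuse along the feature dimension
    return classifier(z)                                  # classifier in_features = hidden_size + 50

Under that assumption, the out_dim placeholder above would become the AST hidden size plus 50 rather than 50.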
models/config/train_local.yaml
ADDED
@@ -0,0 +1,47 @@
+global:
+  id: decision_tree
+  device: mps
+  seed: 42
+  dance_ids:
+    - ATN
+    - BCH
+    - CHA
+    - ECS
+    - HST
+    - JIV
+    - QST
+    - RMB
+    - SFT
+    - SLS
+    - SMB
+    - SWZ
+    - TGO
+    - VWZ
+    - WCS
+data_module:
+  song_data_path: data/songs_cleaned.csv
+  song_audio_path: data/samples
+  batch_size: 32
+  num_workers: 4
+  min_votes: 1
+  dataset_kwargs:
+    audio_window_duration: 6
+    audio_window_jitter: 1.5
+    audio_pipeline_kwargs:
+      mask_count: 0 # Don't mask the data
+      snr_mean: 15.0 # Pretty much eliminate the noise
+      freq_mask_size: 10
+      time_mask_size: 80
+
+trainer:
+  log_every_n_steps: 15
+  accelerator: gpu
+  max_epochs: 50
+  min_epochs: 5
+  fast_dev_run: False
+  # gradient_clip_val: 0.5
+  # overfit_batches: 1
+training_environment:
+  learning_rate: 0.00053
+model:
+  n_channels: 128
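These fields line up with what the new train_decision_tree function in train.py reads: global.id routes to the training function, global.device and global.seed configure the run, and the data_module paths locate the data. A minimal sketch of loading the file directly with PyYAML (the repo's own get_config helper may do more than this):

import yaml

with open("models/config/train_local.yaml") as f:
    config = yaml.safe_load(f)

assert config["global"]["id"] == "decision_tree"          # selects train_decision_tree
device = config["global"]["device"]                        # "mps"
dance_ids = config["global"]["dance_ids"]                  # the 15 dance codes
song_data_path = config["data_module"]["song_data_path"]
song_audio_path = config["data_module"]["song_audio_path"]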
models/decision_tree.py
ADDED
@@ -0,0 +1,124 @@
+
+from sklearn.base import ClassifierMixin, BaseEstimator
+import pandas as pd
+from torch import nn
+import torch
+from typing import Iterator
+import numpy as np
+import json
+from tqdm import tqdm
+import librosa
+DANCE_INFO_FILE = "data/dance_info.csv"
+dance_info_df = pd.read_csv(DANCE_INFO_FILE, converters={'tempoRange': lambda s : json.loads(s.replace("'", '"'))})
+
+class DanceTreeClassifier(BaseEstimator,ClassifierMixin):
+    """
+    Trains a series of binary classifiers to classify each dance when a song falls into its bpm range.
+
+    Features:
+    - Spectrogram
+    - BPM
+    """
+
+    def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
+        self.device=device
+        self.epochs=epochs
+        self.verbose = verbose
+        self.lr = lr
+        self.classifiers = {}
+        self.optimizers = {}
+        self.criterion = nn.BCELoss()
+
+    def get_valid_dances_from_bpm(self,bpm:float) -> list[str]:
+        mask = dance_info_df["tempoRange"].apply(lambda interval: interval["min"] <= bpm <= interval["max"])
+        return list(dance_info_df["id"][mask])
+
+
+
+    def fit(self, x, y):
+        """
+        x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
+        y: (batch_size, n_classes)
+        """
+        progress_bar = tqdm(range(self.epochs))
+        for _ in progress_bar:
+            # TODO: Introduce batches
+            epoch_loss = 0
+            pred_count = 0
+            for (spec, bpm), label in zip(x, y):
+                # find all models that are in the bpm range
+                matching_dances = self.get_valid_dances_from_bpm(bpm)
+                for dance in matching_dances:
+                    if dance not in self.classifiers or dance not in self.optimizers:
+                        classifier = DanceCNN()
+                        self.classifiers[dance] = classifier
+                        self.optimizers[dance] = torch.optim.Adam(classifier.parameters(), lr=self.lr)
+                models = [(dance, model, self.optimizers[dance]) for dance, model in self.classifiers.items() if dance in matching_dances]
+                for dance, model,opt in models:
+                    opt.zero_grad()
+                    spec = torch.from_numpy(spec).to(self.device)
+                    output = model(spec)
+                    target = torch.tensor(float(dance == label))
+                    loss = self.criterion(output, target)
+                    epoch_loss += loss.item()
+                    pred_count +=1
+                    loss.backward()
+                    opt.step()
+            progress_bar.set_description(f"Loss: {epoch_loss / pred_count}")
+
+    def predict(self, x) -> list[str]:
+        results = []
+        for spec, bpm in zip(*x):
+            matching_dances = self.get_valid_dances_from_bpm(bpm)
+            dance_i = torch.tensor([self.classifiers[dance](spec) for dance in matching_dances]).argmax()
+            results.append(matching_dances[dance_i])
+        return results
+
+
+
+
+class DanceCNN(nn.Module):
+    def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        kernel_size=(3,9)
+        self.cnn = nn.Sequential(
+            nn.Conv2d(1,16, kernel_size=kernel_size),
+            nn.ReLU(),
+            nn.MaxPool2d((2,10)),
+            nn.Conv2d(16,32, kernel_size=kernel_size),
+            nn.ReLU(),
+            nn.MaxPool2d((2,10))
+        )
+
+        embedding_dimension = 32* 3 * 959
+        self.classifier = nn.Sequential(
+            nn.Linear(embedding_dimension, 200),
+            nn.ReLU(),
+            nn.Linear(200, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        x = self.cnn(x)
+        x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
+        return self.classifier(x)
+
+def features_from_path(paths:list[str],
+                       audio_window_duration=6,
+                       audio_duration=30,
+                       resample_freq=16000) -> Iterator[tuple[np.array, float]]:
+    """
+    Loads audio and bpm from an audio path.
+    """
+
+    for path in paths:
+        waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
+        num_frames = audio_window_duration * sr
+        tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
+        mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
+        mfccs_normalized = (mfccs - mfccs.mean()) / mfccs.std()
+        mfccs_padded = librosa.util.fix_length(mfccs_normalized, size=sr*audio_duration, axis=1)
+        mfccs_reshaped = mfccs_padded.reshape(1, mfccs_padded.shape[0], mfccs_padded.shape[1])
+        for i in range(audio_duration//audio_window_duration):
+            mfcc_window = mfccs_reshaped[:,:,i*num_frames:(i+1)*num_frames]
+            yield (mfcc_window, tempo)
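A rough usage sketch of the new feature pipeline and the BPM gating it feeds. The paths below are placeholders, and data/dance_info.csv must be present since it is read at import time:

from models.decision_tree import DanceTreeClassifier, features_from_path

# Placeholder paths; real ones come from the song dataset under data/samples.
paths = ["data/samples/song_a.wav", "data/samples/song_b.wav"]

model = DanceTreeClassifier(device="cpu")
for mfcc_window, tempo in features_from_path(paths):
    # Each 30-second clip yields one (MFCC window, tempo) pair per 6-second window.
    candidates = model.get_valid_dances_from_bpm(tempo)
    print(mfcc_window.shape, tempo, candidates)

fit then trains one small DanceCNN per candidate dance, and predict picks the highest-scoring dance among those whose tempo range contains the song's estimated BPM.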
train.py
CHANGED
@@ -11,6 +11,7 @@ from models.audio_spectrogram_transformer import train as train_audio_spectrogra
 from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
 from preprocessing.preprocess import get_examples
 from models.residual import ResidualDancer, TrainingEnvironment
+from models.decision_tree import DanceTreeClassifier, features_from_path
 import yaml
 from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
 from torch.utils.data import random_split
@@ -32,6 +33,8 @@ def get_training_fn(id:str) -> Callable:
             return train_ast
         case "residual_dancer":
             return train_model
+        case "decision_tree":
+            return train_decision_tree
         case _:
             raise Exception(f"Couldn't find a training function for '{id}'.")
 
@@ -143,9 +146,31 @@ def train_ast_lightning(config:dict):
     trainer.fit(train_env, datamodule=data)
     trainer.test(train_env, datamodule=data)
 
+
+def train_decision_tree(config:dict):
+    TARGET_CLASSES = config["global"]["dance_ids"]
+    DEVICE = config["global"]["device"]
+    SEED = config["global"]["seed"]
+    song_data_path=config['data_module']["song_data_path"]
+    song_audio_path = config['data_module']["song_audio_path"]
+    pl.seed_everything(SEED, workers=True)
+
+    df = pd.read_csv(song_data_path)
+    x, y = get_examples(df, song_audio_path,class_list=TARGET_CLASSES, multi_label=True)
+    # Convert y back to string classes
+    y = np.array(TARGET_CLASSES)[y.argmax(-1)]
+    train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
+    train_paths, train_y = x[train_i], y[train_i]
+    train_x = features_from_path(train_paths)
+    model = DanceTreeClassifier(device=DEVICE)
+    model.fit(train_x, train_y)
+    model.save()
+
 if __name__ == "__main__":
     parser = ArgumentParser(description="Trains models on the dance dataset and saves weights.")
-    parser.add_argument("--config",
+    parser.add_argument("--config",
+                        help="Path to the yaml file that defines the training configuration.",
+                        default="models/config/train.yaml")
     args = parser.parse_args()
     config = get_config(args.config)
     training_id = config["global"]["id"]
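With the config, model, and training function wired together, a decision tree run would presumably be launched through the existing entry point by pointing --config at the new file:

python train.py --config models/config/train_local.yaml

get_config loads the YAML, and global.id ("decision_tree") selects train_decision_tree via get_training_fn.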