waidhoferj committed
Commit 3a0f0a5
1 Parent(s): b6800ef

added decision tree

environment.yml CHANGED
@@ -9,6 +9,7 @@ dependencies:
  - pytorch
  - torchaudio
  - torchvision
+ - librosa
  - numpy
  - pandas
  - seaborn
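
librosa is the one new dependency; it supplies the tempo estimation and MFCC features used by the new decision tree model below. A minimal sketch of the calls it enables (the sample path is hypothetical):

import librosa

# Load a clip, estimate a global tempo (bpm), and compute 20 MFCC coefficients.
waveform, sr = librosa.load("data/samples/example.wav", mono=True, sr=16000)  # hypothetical file
tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)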
models/audio_spectrogram_transformer.py CHANGED
@@ -1,4 +1,4 @@
- from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer
+ from transformers import ASTModel, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer
  import torch
  from torch import nn
  from sklearn.utils.class_weight import compute_class_weight
@@ -7,6 +7,53 @@ import numpy as np

  accuracy = evaluate.load("accuracy")

+
+ class MultiModalAST(nn.Module):
+
+     def __init__(self, labels, sample_rate, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         id2label, label2id = get_id_label_mapping(labels)
+         model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
+         self.ast_feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
+
+         self.ast_model = ASTModel.from_pretrained(
+             model_checkpoint,
+             num_labels=len(label2id),
+             label2id=label2id,
+             id2label=id2label,
+             ignore_mismatched_sizes=True
+         )
+         self.sample_rate = sample_rate
+
+         self.bpm_model = nn.Sequential(
+             nn.Linear(len(labels), 100),
+             nn.Linear(100, 50)
+         )
+
+         out_dim = 50  # TODO: Calculate output dimension
+         self.classifier = nn.Sequential(
+             nn.Linear(out_dim, 100),
+             nn.Linear(100, len(labels))
+         )
+
+     def vectorize_bpm(self, waveform):
+         pass
+
+     def forward(self, audio):
+         bpm_vector = self.vectorize_bpm(audio)
+         bpm_out = self.bpm_model(bpm_vector)
+
+         spectrogram = self.ast_feature_extractor(audio)
+         ast_out = self.ast_model(spectrogram)
+
+         # Late fusion
+         z = torch.cat([ast_out, bpm_out])  # Which dimension?
+         return self.classifier(z)
+
+
  def compute_metrics(eval_pred):
      predictions = np.argmax(eval_pred.predictions, axis=1)
      return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
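
vectorize_bpm and the fusion dimension are left as TODOs in this commit. One possible direction, sketched here for illustration only (vectorize_bpm_sketch, n_buckets, and bpm_range are hypothetical names, not part of the repo), is to bucket the estimated tempo into a fixed-size one-hot vector and fuse along the feature axis:

import librosa
import torch

def vectorize_bpm_sketch(waveform, sample_rate, n_buckets=15, bpm_range=(60, 210)):
    """Hypothetical helper: bucket the estimated tempo into a one-hot vector.

    Choosing n_buckets equal to len(labels) would keep the output compatible
    with the nn.Linear(len(labels), 100) input of bpm_model above.
    """
    tempo, _ = librosa.beat.beat_track(y=waveform, sr=sample_rate)
    lo, hi = bpm_range
    bucket = int((min(max(float(tempo), lo), hi - 1) - lo) / (hi - lo) * n_buckets)
    vector = torch.zeros(n_buckets)
    vector[bucket] = 1.0
    return vector

# With late fusion along the last dimension, z = torch.cat([ast_out, bpm_out], dim=-1)
# and out_dim becomes the sum of the two branch widths.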
models/config/train_local.yaml ADDED
@@ -0,0 +1,47 @@
+ global:
+   id: decision_tree
+   device: mps
+   seed: 42
+   dance_ids:
+     - ATN
+     - BCH
+     - CHA
+     - ECS
+     - HST
+     - JIV
+     - QST
+     - RMB
+     - SFT
+     - SLS
+     - SMB
+     - SWZ
+     - TGO
+     - VWZ
+     - WCS
+ data_module:
+   song_data_path: data/songs_cleaned.csv
+   song_audio_path: data/samples
+   batch_size: 32
+   num_workers: 4
+   min_votes: 1
+   dataset_kwargs:
+     audio_window_duration: 6
+     audio_window_jitter: 1.5
+     audio_pipeline_kwargs:
+       mask_count: 0 # Don't mask the data
+       snr_mean: 15.0 # Pretty much eliminate the noise
+       freq_mask_size: 10
+       time_mask_size: 80
+
+ trainer:
+   log_every_n_steps: 15
+   accelerator: gpu
+   max_epochs: 50
+   min_epochs: 5
+   fast_dev_run: False
+   # gradient_clip_val: 0.5
+   # overfit_batches: 1
+ training_environment:
+   learning_rate: 0.00053
+ model:
+   n_channels: 128
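
For reference, a minimal sketch of reading this new config with PyYAML (train.py's get_config presumably does the equivalent); the decision tree path only consumes the global and data_module blocks:

import yaml

with open("models/config/train_local.yaml") as f:
    config = yaml.safe_load(f)

assert config["global"]["id"] == "decision_tree"          # routes to train_decision_tree
device = config["global"]["device"]                       # "mps" in this file
song_data_path = config["data_module"]["song_data_path"]  # data/songs_cleaned.csv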
models/decision_tree.py ADDED
@@ -0,0 +1,124 @@
+
+ from sklearn.base import ClassifierMixin, BaseEstimator
+ import pandas as pd
+ from torch import nn
+ import torch
+ from typing import Iterator
+ import numpy as np
+ import json
+ from tqdm import tqdm
+ import librosa
+
+ DANCE_INFO_FILE = "data/dance_info.csv"
+ dance_info_df = pd.read_csv(
+     DANCE_INFO_FILE,
+     converters={"tempoRange": lambda s: json.loads(s.replace("'", '"'))},
+ )
+
+
+ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
+     """
+     Trains one binary classifier per dance and only consults the classifiers whose
+     dance's tempo range contains the song's bpm.
+
+     Features:
+         - Spectrogram (MFCC windows)
+         - BPM
+     """
+
+     def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
+         self.device = device
+         self.epochs = epochs
+         self.verbose = verbose
+         self.lr = lr
+         self.classifiers = {}
+         self.optimizers = {}
+         self.criterion = nn.BCELoss()
+
+     def get_valid_dances_from_bpm(self, bpm: float) -> list[str]:
+         mask = dance_info_df["tempoRange"].apply(
+             lambda interval: interval["min"] <= bpm <= interval["max"]
+         )
+         return list(dance_info_df["id"][mask])
+
+     def fit(self, x, y):
+         """
+         x: iterable of (spec, bpm) pairs, as produced by features_from_path.
+            spec shape should be (channel, freq_bins, sr * time).
+         y: iterable of dance label strings, one per (spec, bpm) pair.
+         """
+         # x may be a lazy generator; materialize it so every epoch sees the data.
+         examples = list(zip(x, y))
+         progress_bar = tqdm(range(self.epochs))
+         for _ in progress_bar:
+             # TODO: Introduce batches
+             epoch_loss = 0
+             pred_count = 0
+             for (spec, bpm), label in examples:
+                 # Find all dances whose tempo range contains this bpm.
+                 matching_dances = self.get_valid_dances_from_bpm(bpm)
+                 # Lazily create a classifier/optimizer pair for each new dance.
+                 for dance in matching_dances:
+                     if dance not in self.classifiers or dance not in self.optimizers:
+                         classifier = DanceCNN().to(self.device)
+                         self.classifiers[dance] = classifier
+                         self.optimizers[dance] = torch.optim.Adam(classifier.parameters(), lr=self.lr)
+                 # Convert the spectrogram once, outside the per-model loop.
+                 spec_t = torch.from_numpy(spec).to(self.device)
+                 models = [
+                     (dance, model, self.optimizers[dance])
+                     for dance, model in self.classifiers.items()
+                     if dance in matching_dances
+                 ]
+                 for dance, model, opt in models:
+                     opt.zero_grad()
+                     output = model(spec_t)
+                     # Positive target only for the dance that matches the label.
+                     target = torch.tensor([float(dance == label)], device=self.device)
+                     loss = self.criterion(output, target)
+                     epoch_loss += loss.item()
+                     pred_count += 1
+                     loss.backward()
+                     opt.step()
+             progress_bar.set_description(f"Loss: {epoch_loss / pred_count}")
+
+     def predict(self, x) -> list[str]:
+         results = []
+         for spec, bpm in x:
+             spec_t = torch.from_numpy(spec).to(self.device)
+             # Only score dances that fit the bpm and have a trained classifier.
+             matching_dances = [d for d in self.get_valid_dances_from_bpm(bpm) if d in self.classifiers]
+             with torch.no_grad():
+                 scores = [self.classifiers[dance](spec_t).item() for dance in matching_dances]
+             results.append(matching_dances[int(np.argmax(scores))])
+         return results
+
+
+ class DanceCNN(nn.Module):
+     def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         kernel_size = (3, 9)
+         self.cnn = nn.Sequential(
+             nn.Conv2d(1, 16, kernel_size=kernel_size),
+             nn.ReLU(),
+             nn.MaxPool2d((2, 10)),
+             nn.Conv2d(16, 32, kernel_size=kernel_size),
+             nn.ReLU(),
+             nn.MaxPool2d((2, 10))
+         )
+
+         embedding_dimension = 32 * 3 * 959
+         self.classifier = nn.Sequential(
+             nn.Linear(embedding_dimension, 200),
+             nn.ReLU(),
+             nn.Linear(200, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         x = self.cnn(x)
+         # Flatten to (features,) for a single example or (batch, features) for a batch.
+         x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
+         return self.classifier(x)
+
+
+ def features_from_path(paths: list[str],
+                        audio_window_duration=6,
+                        audio_duration=30,
+                        resample_freq=16000) -> Iterator[tuple[np.array, float]]:
+     """
+     Yields (mfcc_window, bpm) pairs for each audio path, one pair per window.
+     """
+     for path in paths:
+         waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
+         num_frames = audio_window_duration * sr
+         tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
+         mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
+         mfccs_normalized = (mfccs - mfccs.mean()) / mfccs.std()
+         mfccs_padded = librosa.util.fix_length(mfccs_normalized, size=sr * audio_duration, axis=1)
+         mfccs_reshaped = mfccs_padded.reshape(1, mfccs_padded.shape[0], mfccs_padded.shape[1])
+         for i in range(audio_duration // audio_window_duration):
+             mfcc_window = mfccs_reshaped[:, :, i * num_frames:(i + 1) * num_frames]
+             yield (mfcc_window, tempo)
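
A minimal usage sketch for the new module (file paths and labels below are hypothetical, and the repo's data/dance_info.csv must be present; features_from_path yields five 6-second windows per 30-second clip, so labels are repeated per window):

from models.decision_tree import DanceTreeClassifier, features_from_path

paths = ["data/samples/song_a.wav", "data/samples/song_b.wav"]  # hypothetical files
song_labels = ["CHA", "JIV"]                                    # one dance id per song
features = list(features_from_path(paths))                      # [(mfcc_window, bpm), ...]
window_labels = [label for label in song_labels for _ in range(5)]
model = DanceTreeClassifier(device="cpu", epochs=1)
model.fit(features, window_labels)
print(model.predict(features[:2]))  # list of predicted dance ids, one per window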
train.py CHANGED
@@ -11,6 +11,7 @@ from models.audio_spectrogram_transformer import train as train_audio_spectrogra
  from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
  from preprocessing.preprocess import get_examples
  from models.residual import ResidualDancer, TrainingEnvironment
+ from models.decision_tree import DanceTreeClassifier, features_from_path
  import yaml
  from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
  from torch.utils.data import random_split
@@ -32,6 +33,8 @@ def get_training_fn(id:str) -> Callable:
              return train_ast
          case "residual_dancer":
              return train_model
+         case "decision_tree":
+             return train_decision_tree
          case _:
              raise Exception(f"Couldn't find a training function for '{id}'.")

@@ -143,9 +146,31 @@ def train_ast_lightning(config:dict):
      trainer.fit(train_env, datamodule=data)
      trainer.test(train_env, datamodule=data)

+
+ def train_decision_tree(config: dict):
+     TARGET_CLASSES = config["global"]["dance_ids"]
+     DEVICE = config["global"]["device"]
+     SEED = config["global"]["seed"]
+     song_data_path = config["data_module"]["song_data_path"]
+     song_audio_path = config["data_module"]["song_audio_path"]
+     pl.seed_everything(SEED, workers=True)
+
+     df = pd.read_csv(song_data_path)
+     x, y = get_examples(df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True)
+     # Convert y back to string classes
+     y = np.array(TARGET_CLASSES)[y.argmax(-1)]
+     train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
+     train_paths, train_y = x[train_i], y[train_i]
+     train_x = features_from_path(train_paths)
+     model = DanceTreeClassifier(device=DEVICE)
+     model.fit(train_x, train_y)
+     model.save()
+
  if __name__ == "__main__":
      parser = ArgumentParser(description="Trains models on the dance dataset and saves weights.")
-     parser.add_argument("--config", help="Path to the yaml file that defines the training configuration.", default="models/config/train.yaml")
+     parser.add_argument("--config",
+                         help="Path to the yaml file that defines the training configuration.",
+                         default="models/config/train.yaml")
      args = parser.parse_args()
      config = get_config(args.config)
      training_id = config["global"]["id"]
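
With this commit, decision tree training is driven by the new config file; the entry point presumably dispatches on global.id in the usual way, roughly:

# python train.py --config models/config/train_local.yaml
config = get_config("models/config/train_local.yaml")
train = get_training_fn(config["global"]["id"])  # "decision_tree" -> train_decision_tree
train(config)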