waidhoferj commited on
Commit
3b31903
β€’
1 Parent(s): 3a0f0a5

updated packages

Browse files
.gitignore CHANGED
@@ -8,3 +8,4 @@ scrapers/auth
8
  lightning_logs
9
  .lr_find_*
10
  .cache
 
 
8
  lightning_logs
9
  .lr_find_*
10
  .cache
11
+ .vscode
TODO.md CHANGED
@@ -9,3 +9,7 @@
9
  - Read the Medium series about audio DL
10
  - double check \_rectify_duration
11
  - βœ… Filter out songs that have only one vote
 
 
 
 
 
9
  - Read the Medium series about audio DL
10
  - double check \_rectify_duration
11
  - βœ… Filter out songs that have only one vote
12
+
13
+ ## Notes
14
+
15
+ 2xM60 insufficient memory.
environment.yml CHANGED
@@ -22,7 +22,7 @@ dependencies:
22
  - rich
23
  - scikit-learn
24
  - tensorboard
 
25
  - pip:
26
- - git+https://github.com/huggingface/transformers.git
27
  - evaluate
28
  - wakepy
 
22
  - rich
23
  - scikit-learn
24
  - tensorboard
25
+ - transformers
26
  - pip:
 
27
  - evaluate
28
  - wakepy
models/config/train_local.yaml CHANGED
@@ -1,5 +1,5 @@
1
  global:
2
- id: decision_tree
3
  device: mps
4
  seed: 42
5
  dance_ids:
@@ -22,7 +22,7 @@ data_module:
22
  song_data_path: data/songs_cleaned.csv
23
  song_audio_path: data/samples
24
  batch_size: 32
25
- num_workers: 4
26
  min_votes: 1
27
  dataset_kwargs:
28
  audio_window_duration: 6
 
1
  global:
2
+ id: ast_ptl # decision_tree
3
  device: mps
4
  seed: 42
5
  dance_ids:
 
22
  song_data_path: data/songs_cleaned.csv
23
  song_audio_path: data/samples
24
  batch_size: 32
25
+ num_workers: 7
26
  min_votes: 1
27
  dataset_kwargs:
28
  audio_window_duration: 6
models/decision_tree.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from sklearn.base import ClassifierMixin, BaseEstimator
3
  import pandas as pd
4
  from torch import nn
@@ -8,10 +7,15 @@ import numpy as np
8
  import json
9
  from tqdm import tqdm
10
  import librosa
 
11
  DANCE_INFO_FILE = "data/dance_info.csv"
12
- dance_info_df = pd.read_csv(DANCE_INFO_FILE, converters={'tempoRange': lambda s : json.loads(s.replace("'", '"'))})
 
 
 
 
13
 
14
- class DanceTreeClassifier(BaseEstimator,ClassifierMixin):
15
  """
16
  Trains a series of binary classifiers to classify each dance when a song falls into its bpm range.
17
 
@@ -21,20 +25,20 @@ class DanceTreeClassifier(BaseEstimator,ClassifierMixin):
21
  """
22
 
23
  def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
24
- self.device=device
25
- self.epochs=epochs
26
  self.verbose = verbose
27
  self.lr = lr
28
  self.classifiers = {}
29
  self.optimizers = {}
30
  self.criterion = nn.BCELoss()
31
 
32
- def get_valid_dances_from_bpm(self,bpm:float) -> list[str]:
33
- mask = dance_info_df["tempoRange"].apply(lambda interval: interval["min"] <= bpm <= interval["max"])
 
 
34
  return list(dance_info_df["id"][mask])
35
 
36
-
37
-
38
  def fit(self, x, y):
39
  """
40
  x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
@@ -45,57 +49,73 @@ class DanceTreeClassifier(BaseEstimator,ClassifierMixin):
45
  # TODO: Introduce batches
46
  epoch_loss = 0
47
  pred_count = 0
 
48
  for (spec, bpm), label in zip(x, y):
 
49
  # find all models that are in the bpm range
50
  matching_dances = self.get_valid_dances_from_bpm(bpm)
 
51
  for dance in matching_dances:
52
  if dance not in self.classifiers or dance not in self.optimizers:
53
- classifier = DanceCNN()
54
  self.classifiers[dance] = classifier
55
- self.optimizers[dance] = torch.optim.Adam(classifier.parameters(), lr=self.lr)
56
- models = [(dance, model, self.optimizers[dance]) for dance, model in self.classifiers.items() if dance in matching_dances]
57
- for dance, model,opt in models:
 
 
 
 
 
 
58
  opt.zero_grad()
59
- spec = torch.from_numpy(spec).to(self.device)
60
  output = model(spec)
61
- target = torch.tensor(float(dance == label))
62
  loss = self.criterion(output, target)
63
  epoch_loss += loss.item()
64
- pred_count +=1
65
  loss.backward()
66
  opt.step()
67
- progress_bar.set_description(f"Loss: {epoch_loss / pred_count}")
 
 
68
 
69
  def predict(self, x) -> list[str]:
70
  results = []
71
  for spec, bpm in zip(*x):
72
  matching_dances = self.get_valid_dances_from_bpm(bpm)
73
- dance_i = torch.tensor([self.classifiers[dance](spec) for dance in matching_dances]).argmax()
 
 
74
  results.append(matching_dances[dance_i])
75
  return results
76
 
77
-
78
-
79
 
80
  class DanceCNN(nn.Module):
81
  def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
82
  super().__init__(*args, **kwargs)
83
- kernel_size=(3,9)
84
  self.cnn = nn.Sequential(
85
- nn.Conv2d(1,16, kernel_size=kernel_size),
86
  nn.ReLU(),
87
- nn.MaxPool2d((2,10)),
88
- nn.Conv2d(16,32, kernel_size=kernel_size),
89
  nn.ReLU(),
90
- nn.MaxPool2d((2,10))
 
 
 
 
 
 
91
  )
92
 
93
- embedding_dimension = 32* 3 * 959
94
  self.classifier = nn.Sequential(
95
  nn.Linear(embedding_dimension, 200),
96
  nn.ReLU(),
97
  nn.Linear(200, 1),
98
- nn.Sigmoid()
99
  )
100
 
101
  def forward(self, x):
@@ -103,22 +123,25 @@ class DanceCNN(nn.Module):
103
  x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
104
  return self.classifier(x)
105
 
106
- def features_from_path(paths:list[str],
107
- audio_window_duration=6,
108
- audio_duration=30,
109
- resample_freq=16000) -> Iterator[tuple[np.array, float]]:
110
  """
111
  Loads audio and bpm from an audio path.
112
  """
113
-
114
  for path in paths:
115
  waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
116
- num_frames = audio_window_duration * sr
117
  tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
 
118
  mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
119
- mfccs_normalized = (mfccs - mfccs.mean()) / mfccs.std()
120
- mfccs_padded = librosa.util.fix_length(mfccs_normalized, size=sr*audio_duration, axis=1)
121
- mfccs_reshaped = mfccs_padded.reshape(1, mfccs_padded.shape[0], mfccs_padded.shape[1])
122
- for i in range(audio_duration//audio_window_duration):
123
- mfcc_window = mfccs_reshaped[:,:,i*num_frames:(i+1)*num_frames]
124
- yield (mfcc_window, tempo)
 
 
 
 
1
  from sklearn.base import ClassifierMixin, BaseEstimator
2
  import pandas as pd
3
  from torch import nn
 
7
  import json
8
  from tqdm import tqdm
9
  import librosa
10
+
11
  DANCE_INFO_FILE = "data/dance_info.csv"
12
+ dance_info_df = pd.read_csv(
13
+ DANCE_INFO_FILE,
14
+ converters={"tempoRange": lambda s: json.loads(s.replace("'", '"'))},
15
+ )
16
+
17
 
18
+ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
19
  """
20
  Trains a series of binary classifiers to classify each dance when a song falls into its bpm range.
21
 
 
25
  """
26
 
27
  def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
28
+ self.device = device
29
+ self.epochs = epochs
30
  self.verbose = verbose
31
  self.lr = lr
32
  self.classifiers = {}
33
  self.optimizers = {}
34
  self.criterion = nn.BCELoss()
35
 
36
+ def get_valid_dances_from_bpm(self, bpm: float) -> list[str]:
37
+ mask = dance_info_df["tempoRange"].apply(
38
+ lambda interval: interval["min"] <= bpm <= interval["max"]
39
+ )
40
  return list(dance_info_df["id"][mask])
41
 
 
 
42
  def fit(self, x, y):
43
  """
44
  x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
 
49
  # TODO: Introduce batches
50
  epoch_loss = 0
51
  pred_count = 0
52
+ step = 0
53
  for (spec, bpm), label in zip(x, y):
54
+ step += 1
55
  # find all models that are in the bpm range
56
  matching_dances = self.get_valid_dances_from_bpm(bpm)
57
+ spec = torch.from_numpy(spec).to(self.device)
58
  for dance in matching_dances:
59
  if dance not in self.classifiers or dance not in self.optimizers:
60
+ classifier = DanceCNN().to(self.device)
61
  self.classifiers[dance] = classifier
62
+ self.optimizers[dance] = torch.optim.Adam(
63
+ classifier.parameters(), lr=self.lr
64
+ )
65
+ models = [
66
+ (dance, model, self.optimizers[dance])
67
+ for dance, model in self.classifiers.items()
68
+ if dance in matching_dances
69
+ ]
70
+ for model_i, (dance, model, opt) in enumerate(models):
71
  opt.zero_grad()
 
72
  output = model(spec)
73
+ target = torch.tensor([float(dance == label)], device=self.device)
74
  loss = self.criterion(output, target)
75
  epoch_loss += loss.item()
76
+ pred_count += 1
77
  loss.backward()
78
  opt.step()
79
+ progress_bar.set_description(
80
+ f"Loss: {epoch_loss / pred_count}, Step: {step}, Model: {model_i+1}/{len(models)}"
81
+ )
82
 
83
  def predict(self, x) -> list[str]:
84
  results = []
85
  for spec, bpm in zip(*x):
86
  matching_dances = self.get_valid_dances_from_bpm(bpm)
87
+ dance_i = torch.tensor(
88
+ [self.classifiers[dance](spec) for dance in matching_dances]
89
+ ).argmax()
90
  results.append(matching_dances[dance_i])
91
  return results
92
 
 
 
93
 
94
  class DanceCNN(nn.Module):
95
  def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
96
  super().__init__(*args, **kwargs)
97
+ kernel_size = (3, 9)
98
  self.cnn = nn.Sequential(
99
+ nn.Conv2d(1, 16, kernel_size=kernel_size),
100
  nn.ReLU(),
101
+ nn.MaxPool2d((2, 10)),
102
+ nn.Conv2d(16, 32, kernel_size=kernel_size),
103
  nn.ReLU(),
104
+ nn.MaxPool2d((2, 10)),
105
+ nn.Conv2d(32, 32, kernel_size=kernel_size),
106
+ nn.ReLU(),
107
+ nn.MaxPool2d((2, 10)),
108
+ nn.Conv2d(32, 16, kernel_size=kernel_size),
109
+ nn.ReLU(),
110
+ nn.MaxPool2d((2, 10)),
111
  )
112
 
113
+ embedding_dimension = 16 * 6 * 8
114
  self.classifier = nn.Sequential(
115
  nn.Linear(embedding_dimension, 200),
116
  nn.ReLU(),
117
  nn.Linear(200, 1),
118
+ nn.Sigmoid(),
119
  )
120
 
121
  def forward(self, x):
 
123
  x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
124
  return self.classifier(x)
125
 
126
+
127
+ def features_from_path(
128
+ paths: list[str], audio_window_duration=6, audio_duration=30, resample_freq=16000
129
+ ) -> Iterator[tuple[np.array, float]]:
130
  """
131
  Loads audio and bpm from an audio path.
132
  """
133
+
134
  for path in paths:
135
  waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
136
+ num_frames = audio_window_duration * sr
137
  tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
138
+ spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
139
  mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)
140
+ spec_normalized = (spec - spec.mean()) / spec.std()
141
+ spec_padded = librosa.util.fix_length(
142
+ spec_normalized, size=sr * audio_duration, axis=1
143
+ )
144
+ batched_spec = np.expand_dims(spec_padded, axis=0)
145
+ for i in range(audio_duration // audio_window_duration):
146
+ spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
147
+ yield (spec_window, tempo)
preprocessing/dataset.py CHANGED
@@ -12,19 +12,23 @@ from torch import nn
12
  from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
13
 
14
 
15
-
16
  class SongDataset(Dataset):
17
- def __init__(self,
18
- audio_paths: list[str],
19
- dance_labels: list[np.ndarray],
20
- audio_duration=30, # seconds
21
- audio_window_duration=6, # seconds
22
- audio_window_jitter=0.0, # seconds
23
- audio_pipeline_kwargs={},
24
- resample_frequency=16000
 
25
  ):
26
- assert audio_duration % audio_window_duration == 0, "Audio window should divide duration evenly."
27
- assert audio_window_duration > audio_window_jitter, "Jitter should be a small fraction of the audio window duration."
 
 
 
 
28
 
29
  self.audio_paths = audio_paths
30
  self.dance_labels = dance_labels
@@ -34,14 +38,21 @@ class SongDataset(Dataset):
34
  self.audio_window_jitter = audio_window_jitter
35
  self.audio_duration = int(audio_duration)
36
 
37
- self.audio_pipeline = AudioTrainingPipeline(self.sample_rate, resample_frequency, audio_window_duration, **audio_pipeline_kwargs)
 
 
 
 
 
38
 
39
  def __len__(self):
40
- return len(self.audio_paths) * self.audio_duration // self.audio_window_duration
41
 
42
- def __getitem__(self, idx:int) -> tuple[torch.Tensor, torch.Tensor]:
43
  waveform = self._waveform_from_index(idx)
44
- assert waveform.shape[1] > 10, f"No data found: {self._backtrace_audio_path(idx)}"
 
 
45
  spectrogram = self.audio_pipeline(waveform)
46
 
47
  dance_labels = self._label_from_index(idx)
@@ -53,206 +64,256 @@ class SongDataset(Dataset):
53
  # Try the previous one
54
  # This happens when some of the audio recordings are really quiet
55
  # This WILL NOT leak into other data partitions because songs belong entirely to a partition
56
- return self[idx-1]
57
 
58
- def _convert_idx(self,idx:int) -> int:
59
  return idx * self.audio_window_duration // self.audio_duration
60
 
61
- def _backtrace_audio_path(self, index:int) -> str:
62
  return self.audio_paths[self._convert_idx(index)]
63
 
64
- def _validate_output(self,x,y):
65
- is_finite = not torch.any(torch.isinf(x))
66
  is_numerical = not torch.any(torch.isnan(x))
67
  has_data = torch.any(x != 0.0)
68
  is_binary = len(torch.unique(y)) < 3
69
- return all((is_finite,is_numerical, has_data, is_binary))
70
 
71
- def _waveform_from_index(self, idx:int) -> torch.Tensor:
72
  audio_filepath = self.audio_paths[self._convert_idx(idx)]
73
  num_windows = self.audio_duration // self.audio_window_duration
74
  frame_index = idx % num_windows
75
  jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
76
  jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
77
- jitter = int(torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate)
78
- frame_offset = frame_index * self.audio_window_duration * self.sample_rate + jitter
 
 
 
 
79
  num_frames = self.sample_rate * self.audio_window_duration
80
- waveform, sample_rate = ta.load(audio_filepath, frame_offset=frame_offset, num_frames=num_frames)
81
- assert sample_rate == self.sample_rate, f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
 
 
 
 
82
  return waveform
83
 
84
-
85
- def _label_from_index(self, idx:int) -> torch.Tensor:
86
  return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
87
-
88
 
89
  class WaveformSongDataset(SongDataset):
90
  """
91
  Outputs raw waveforms of the data instead of a spectrogram.
92
  """
93
 
94
- def __init__(self, *args,resample_frequency=16000, **kwargs):
95
  super().__init__(*args, **kwargs)
96
  self.resample_frequency = resample_frequency
97
  self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
98
  self.pipeline = []
99
 
100
- def __getitem__(self, idx:int) -> dict[str, torch.Tensor]:
101
  waveform = self._waveform_from_index(idx)
102
- assert waveform.shape[1] > 10, f"No data found: {self._backtrace_audio_path(idx)}"
 
 
103
  # resample the waveform
104
  waveform = self.resampler(waveform)
105
-
106
  waveform = waveform.mean(0)
107
 
108
  dance_labels = self._label_from_index(idx)
109
  return waveform, dance_labels
110
-
111
-
112
 
113
 
114
  class HuggingFaceWaveformSongDataset(WaveformSongDataset):
115
-
116
  def __init__(self, *args, **kwargs):
117
  super().__init__(*args, **kwargs)
118
  self.pipeline = []
119
 
120
-
121
- def __getitem__(self, idx:int) -> dict[str, torch.Tensor]:
122
- x,y = super().__getitem__(idx)
123
  if len(self.pipeline) > 0:
124
  for fn in self.pipeline:
125
  x = fn(x)
126
 
127
  dance_labels = y.argmax()
128
- return {"input_values": x["input_values"][0] if hasattr(x, "input_values") else x, "label": dance_labels}
 
 
 
129
 
130
- def map(self,fn):
131
  """
132
  NOTE this mutates the original, doesn't return a copy like normal maps.
133
  """
134
  self.pipeline.append(fn)
135
 
 
136
  class DanceDataModule(pl.LightningDataModule):
137
- def __init__(self,
138
- song_data_path="data/songs_cleaned.csv",
139
- song_audio_path="data/samples",
140
- test_proportion=0.15,
141
- val_proportion=0.1,
142
- target_classes:list[str]=None,
143
- min_votes=1,
144
- batch_size:int=64,
145
- num_workers=10,
146
- dataset_cls = None,
147
- dataset_kwargs={}
 
148
  ):
149
  super().__init__()
150
  self.song_data_path = song_data_path
151
  self.song_audio_path = song_audio_path
152
- self.val_proportion=val_proportion
153
- self.test_proportion=test_proportion
154
- self.train_proportion= 1.-test_proportion-val_proportion
155
- self.target_classes=target_classes
156
  self.batch_size = batch_size
157
  self.num_workers = num_workers
158
  self.dataset_kwargs = dataset_kwargs
159
  self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
160
 
161
  df = pd.read_csv(song_data_path)
162
- self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes, multi_label=True, min_votes=min_votes)
 
 
 
 
 
 
163
 
164
  def setup(self, stage: str):
165
- train_i, val_i, test_i = random_split(np.arange(len(self.x)), [self.train_proportion, self.val_proportion, self.test_proportion])
 
 
 
166
  self.train_ds = self._dataset_from_indices(train_i)
167
  self.val_ds = self._dataset_from_indices(val_i)
168
  self.test_ds = self._dataset_from_indices(test_i)
169
-
170
- def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
171
  return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
172
-
173
  def train_dataloader(self):
174
- return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
 
 
 
 
 
175
 
176
  def val_dataloader(self):
177
- return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers)
 
 
178
 
179
  def test_dataloader(self):
180
- return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers)
 
 
181
 
182
  def get_label_weights(self):
183
  n_examples, n_classes = self.y.shape
184
  return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
185
-
186
 
187
- class WaveformTrainingEnvironment(pl.LightningModule):
188
 
189
- def __init__(self, model: nn.Module, criterion: nn.Module, feature_extractor, config:dict, learning_rate=1e-4, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
190
  super().__init__(*args, **kwargs)
191
  self.model = model
192
  self.criterion = criterion
193
  self.learning_rate = learning_rate
194
- self.config=config
195
- self.feature_extractor=feature_extractor
196
- self.save_hyperparameters({
197
- "model": type(model).__name__,
198
- "loss": type(criterion).__name__,
199
- "config": config,
200
- **kwargs
201
- })
 
 
202
 
203
  def preprocess_inputs(self, x):
204
  device = x.device
205
- x = x.squeeze(1).cpu().numpy()
206
- x = self.feature_extractor(list(x),return_tensors='pt', sampling_rate=16000)
207
  return x["input_values"].to(device)
208
-
209
- def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
 
 
210
  features, labels = batch
211
  features = self.preprocess_inputs(features)
212
  outputs = self.model(features).logits
213
- outputs = nn.Sigmoid()(outputs) # good for multi label classification, should be softmax otherwise
 
 
214
  loss = self.criterion(outputs, labels)
215
  metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
216
  self.log_dict(metrics, prog_bar=True)
217
  return loss
218
 
219
-
220
- def validation_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
221
- x,y = batch
 
222
  x = self.preprocess_inputs(x)
223
  preds = self.model(x).logits
224
- preds = nn.Sigmoid()(preds)
225
  metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
226
  metrics["val/loss"] = self.criterion(preds, y)
227
- self.log_dict(metrics,prog_bar=True)
228
 
229
- def test_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
230
  x, y = batch
231
  x = self.preprocess_inputs(x)
232
  preds = self.model(x).logits
233
- preds = nn.Sigmoid()(preds)
234
- self.log_dict(calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True)
235
-
 
 
236
  def configure_optimizers(self):
237
  optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
238
  # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
239
- return [optimizer]
240
-
241
 
242
 
243
- def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
 
 
244
  target = target.detach().cpu().numpy()
245
  pred = pred.detach().cpu().numpy()
246
  params = {
247
- "y_true": target if multi_label else target.argmax(1) ,
248
- "y_pred": np.array(pred > threshold, dtype=float) if multi_label else pred.argmax(1),
249
- "zero_division": 0,
250
- "average":"macro"
251
- }
252
- metrics= {
253
- 'precision': precision_score(**params),
254
- 'recall': recall_score(**params),
255
- 'f1': f1_score(**params),
256
- 'accuracy': accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
257
- }
258
- return {prefix + k: torch.tensor(v,dtype=torch.float32) for k,v in metrics.items()}
 
 
 
 
 
12
  from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
13
 
14
 
 
15
  class SongDataset(Dataset):
16
+ def __init__(
17
+ self,
18
+ audio_paths: list[str],
19
+ dance_labels: list[np.ndarray],
20
+ audio_duration=30, # seconds
21
+ audio_window_duration=6, # seconds
22
+ audio_window_jitter=0.0, # seconds
23
+ audio_pipeline_kwargs={},
24
+ resample_frequency=16000,
25
  ):
26
+ assert (
27
+ audio_duration % audio_window_duration == 0
28
+ ), "Audio window should divide duration evenly."
29
+ assert (
30
+ audio_window_duration > audio_window_jitter
31
+ ), "Jitter should be a small fraction of the audio window duration."
32
 
33
  self.audio_paths = audio_paths
34
  self.dance_labels = dance_labels
 
38
  self.audio_window_jitter = audio_window_jitter
39
  self.audio_duration = int(audio_duration)
40
 
41
+ self.audio_pipeline = AudioTrainingPipeline(
42
+ self.sample_rate,
43
+ resample_frequency,
44
+ audio_window_duration,
45
+ **audio_pipeline_kwargs,
46
+ )
47
 
48
  def __len__(self):
49
+ return len(self.audio_paths) * self.audio_duration // self.audio_window_duration
50
 
51
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
52
  waveform = self._waveform_from_index(idx)
53
+ assert (
54
+ waveform.shape[1] > 10
55
+ ), f"No data found: {self._backtrace_audio_path(idx)}"
56
  spectrogram = self.audio_pipeline(waveform)
57
 
58
  dance_labels = self._label_from_index(idx)
 
64
  # Try the previous one
65
  # This happens when some of the audio recordings are really quiet
66
  # This WILL NOT leak into other data partitions because songs belong entirely to a partition
67
+ return self[idx - 1]
68
 
69
+ def _convert_idx(self, idx: int) -> int:
70
  return idx * self.audio_window_duration // self.audio_duration
71
 
72
+ def _backtrace_audio_path(self, index: int) -> str:
73
  return self.audio_paths[self._convert_idx(index)]
74
 
75
+ def _validate_output(self, x, y):
76
+ is_finite = not torch.any(torch.isinf(x))
77
  is_numerical = not torch.any(torch.isnan(x))
78
  has_data = torch.any(x != 0.0)
79
  is_binary = len(torch.unique(y)) < 3
80
+ return all((is_finite, is_numerical, has_data, is_binary))
81
 
82
+ def _waveform_from_index(self, idx: int) -> torch.Tensor:
83
  audio_filepath = self.audio_paths[self._convert_idx(idx)]
84
  num_windows = self.audio_duration // self.audio_window_duration
85
  frame_index = idx % num_windows
86
  jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
87
  jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
88
+ jitter = int(
89
+ torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
90
+ )
91
+ frame_offset = (
92
+ frame_index * self.audio_window_duration * self.sample_rate + jitter
93
+ )
94
  num_frames = self.sample_rate * self.audio_window_duration
95
+ waveform, sample_rate = ta.load(
96
+ audio_filepath, frame_offset=frame_offset, num_frames=num_frames
97
+ )
98
+ assert (
99
+ sample_rate == self.sample_rate
100
+ ), f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
101
  return waveform
102
 
103
+ def _label_from_index(self, idx: int) -> torch.Tensor:
 
104
  return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
105
+
106
 
107
  class WaveformSongDataset(SongDataset):
108
  """
109
  Outputs raw waveforms of the data instead of a spectrogram.
110
  """
111
 
112
+ def __init__(self, *args, resample_frequency=16000, **kwargs):
113
  super().__init__(*args, **kwargs)
114
  self.resample_frequency = resample_frequency
115
  self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
116
  self.pipeline = []
117
 
118
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
119
  waveform = self._waveform_from_index(idx)
120
+ assert (
121
+ waveform.shape[1] > 10
122
+ ), f"No data found: {self._backtrace_audio_path(idx)}"
123
  # resample the waveform
124
  waveform = self.resampler(waveform)
125
+
126
  waveform = waveform.mean(0)
127
 
128
  dance_labels = self._label_from_index(idx)
129
  return waveform, dance_labels
 
 
130
 
131
 
132
  class HuggingFaceWaveformSongDataset(WaveformSongDataset):
 
133
  def __init__(self, *args, **kwargs):
134
  super().__init__(*args, **kwargs)
135
  self.pipeline = []
136
 
137
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
138
+ x, y = super().__getitem__(idx)
 
139
  if len(self.pipeline) > 0:
140
  for fn in self.pipeline:
141
  x = fn(x)
142
 
143
  dance_labels = y.argmax()
144
+ return {
145
+ "input_values": x["input_values"][0] if hasattr(x, "input_values") else x,
146
+ "label": dance_labels,
147
+ }
148
 
149
+ def map(self, fn):
150
  """
151
  NOTE this mutates the original, doesn't return a copy like normal maps.
152
  """
153
  self.pipeline.append(fn)
154
 
155
+
156
  class DanceDataModule(pl.LightningDataModule):
157
+ def __init__(
158
+ self,
159
+ song_data_path="data/songs_cleaned.csv",
160
+ song_audio_path="data/samples",
161
+ test_proportion=0.15,
162
+ val_proportion=0.1,
163
+ target_classes: list[str] = None,
164
+ min_votes=1,
165
+ batch_size: int = 64,
166
+ num_workers=10,
167
+ dataset_cls=None,
168
+ dataset_kwargs={},
169
  ):
170
  super().__init__()
171
  self.song_data_path = song_data_path
172
  self.song_audio_path = song_audio_path
173
+ self.val_proportion = val_proportion
174
+ self.test_proportion = test_proportion
175
+ self.train_proportion = 1.0 - test_proportion - val_proportion
176
+ self.target_classes = target_classes
177
  self.batch_size = batch_size
178
  self.num_workers = num_workers
179
  self.dataset_kwargs = dataset_kwargs
180
  self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
181
 
182
  df = pd.read_csv(song_data_path)
183
+ self.x, self.y = get_examples(
184
+ df,
185
+ self.song_audio_path,
186
+ class_list=self.target_classes,
187
+ multi_label=True,
188
+ min_votes=min_votes,
189
+ )
190
 
191
  def setup(self, stage: str):
192
+ train_i, val_i, test_i = random_split(
193
+ np.arange(len(self.x)),
194
+ [self.train_proportion, self.val_proportion, self.test_proportion],
195
+ )
196
  self.train_ds = self._dataset_from_indices(train_i)
197
  self.val_ds = self._dataset_from_indices(val_i)
198
  self.test_ds = self._dataset_from_indices(test_i)
199
+
200
+ def _dataset_from_indices(self, idx: list[int]) -> SongDataset:
201
  return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
202
+
203
  def train_dataloader(self):
204
+ return DataLoader(
205
+ self.train_ds,
206
+ batch_size=self.batch_size,
207
+ num_workers=self.num_workers,
208
+ shuffle=True,
209
+ )
210
 
211
  def val_dataloader(self):
212
+ return DataLoader(
213
+ self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers
214
+ )
215
 
216
  def test_dataloader(self):
217
+ return DataLoader(
218
+ self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers
219
+ )
220
 
221
  def get_label_weights(self):
222
  n_examples, n_classes = self.y.shape
223
  return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
 
224
 
 
225
 
226
+ class WaveformTrainingEnvironment(pl.LightningModule):
227
+ def __init__(
228
+ self,
229
+ model: nn.Module,
230
+ criterion: nn.Module,
231
+ feature_extractor,
232
+ config: dict,
233
+ learning_rate=1e-4,
234
+ *args,
235
+ **kwargs,
236
+ ):
237
  super().__init__(*args, **kwargs)
238
  self.model = model
239
  self.criterion = criterion
240
  self.learning_rate = learning_rate
241
+ self.config = config
242
+ self.feature_extractor = feature_extractor
243
+ self.save_hyperparameters(
244
+ {
245
+ "model": type(model).__name__,
246
+ "loss": type(criterion).__name__,
247
+ "config": config,
248
+ **kwargs,
249
+ }
250
+ )
251
 
252
  def preprocess_inputs(self, x):
253
  device = x.device
254
+ x = list(x.squeeze(1).cpu().numpy())
255
+ x = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000)
256
  return x["input_values"].to(device)
257
+
258
+ def training_step(
259
+ self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
260
+ ) -> torch.Tensor:
261
  features, labels = batch
262
  features = self.preprocess_inputs(features)
263
  outputs = self.model(features).logits
264
+ outputs = nn.Sigmoid()(
265
+ outputs
266
+ ) # good for multi label classification, should be softmax otherwise
267
  loss = self.criterion(outputs, labels)
268
  metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
269
  self.log_dict(metrics, prog_bar=True)
270
  return loss
271
 
272
+ def validation_step(
273
+ self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
274
+ ):
275
+ x, y = batch
276
  x = self.preprocess_inputs(x)
277
  preds = self.model(x).logits
278
+ preds = nn.Sigmoid()(preds)
279
  metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
280
  metrics["val/loss"] = self.criterion(preds, y)
281
+ self.log_dict(metrics, prog_bar=True)
282
 
283
+ def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
284
  x, y = batch
285
  x = self.preprocess_inputs(x)
286
  preds = self.model(x).logits
287
+ preds = nn.Sigmoid()(preds)
288
+ self.log_dict(
289
+ calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True
290
+ )
291
+
292
  def configure_optimizers(self):
293
  optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
294
  # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
295
+ return [optimizer]
 
296
 
297
 
298
+ def calculate_metrics(
299
+ pred, target, threshold=0.5, prefix="", multi_label=True
300
+ ) -> dict[str, torch.Tensor]:
301
  target = target.detach().cpu().numpy()
302
  pred = pred.detach().cpu().numpy()
303
  params = {
304
+ "y_true": target if multi_label else target.argmax(1),
305
+ "y_pred": np.array(pred > threshold, dtype=float)
306
+ if multi_label
307
+ else pred.argmax(1),
308
+ "zero_division": 0,
309
+ "average": "macro",
310
+ }
311
+ metrics = {
312
+ "precision": precision_score(**params),
313
+ "recall": recall_score(**params),
314
+ "f1": f1_score(**params),
315
+ "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
316
+ }
317
+ return {
318
+ prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
319
+ }
train.py CHANGED
@@ -7,25 +7,32 @@ from sklearn.model_selection import KFold
7
  import pytorch_lightning as pl
8
  from pytorch_lightning import callbacks as cb
9
  from models.utils import LabelWeightedBCELoss
10
- from models.audio_spectrogram_transformer import train as train_audio_spectrogram_transformer, get_id_label_mapping
 
 
 
11
  from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
12
  from preprocessing.preprocess import get_examples
13
  from models.residual import ResidualDancer, TrainingEnvironment
14
  from models.decision_tree import DanceTreeClassifier, features_from_path
15
  import yaml
16
- from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
 
 
 
 
17
  from torch.utils.data import random_split
18
  import numpy as np
19
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
20
  from argparse import ArgumentParser
21
 
22
 
23
-
24
  import torch
25
  from torch import nn
26
  from sklearn.utils.class_weight import compute_class_weight
27
 
28
- def get_training_fn(id:str) -> Callable:
 
29
  match id:
30
  case "ast_ptl":
31
  return train_ast_lightning
@@ -38,7 +45,8 @@ def get_training_fn(id:str) -> Callable:
38
  case _:
39
  raise Exception(f"Couldn't find a training function for '{id}'.")
40
 
41
- def get_config(filepath:str) -> dict:
 
42
  with open(filepath, "r") as f:
43
  config = yaml.safe_load(f)
44
  return config
@@ -48,14 +56,14 @@ def cross_validation(config, k=5):
48
  df = pd.read_csv("data/songs.csv")
49
  g_config = config["global"]
50
  batch_size = config["data_module"]["batch_size"]
51
- x,y = get_examples(df, "data/samples",class_list=g_config["dance_ids"])
52
- dataset = SongDataset(x,y)
53
- splits=KFold(n_splits=k,shuffle=True,random_state=g_config["seed"])
54
  trainer = pl.Trainer(accelerator=g_config["device"])
55
- for fold, (train_idx,val_idx) in enumerate(splits.split(x,y)):
56
  print(f"Fold {fold+1}")
57
  model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
58
- train_env = TrainingEnvironment(model,nn.BCELoss())
59
  train_sampler = SubsetRandomSampler(train_idx)
60
  test_sampler = SubsetRandomSampler(val_idx)
61
  train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
@@ -64,15 +72,17 @@ def cross_validation(config, k=5):
64
  trainer.test(train_env, test_loader)
65
 
66
 
67
- def train_model(config:dict):
68
  TARGET_CLASSES = config["global"]["dance_ids"]
69
  DEVICE = config["global"]["device"]
70
  SEED = config["global"]["seed"]
71
  pl.seed_everything(SEED, workers=True)
72
- data = DanceDataModule(target_classes=TARGET_CLASSES, **config['data_module'])
73
- model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config['model'])
74
  label_weights = data.get_label_weights().to(DEVICE)
75
- criterion = LabelWeightedBCELoss(label_weights) #nn.CrossEntropyLoss(label_weights)
 
 
76
  train_env = TrainingEnvironment(model, criterion, config)
77
  callbacks = [
78
  # cb.LearningRateFinder(update_attr=True),
@@ -81,36 +91,41 @@ def train_model(config:dict):
81
  cb.RichProgressBar(),
82
  cb.DeviceStatsMonitor(),
83
  ]
84
- trainer = pl.Trainer(
85
- callbacks=callbacks,
86
- **config["trainer"]
87
- )
88
  trainer.fit(train_env, datamodule=data)
89
  trainer.test(train_env, datamodule=data)
90
 
91
 
92
- def train_ast(
93
- config:dict
94
- ):
95
  TARGET_CLASSES = config["global"]["dance_ids"]
96
  DEVICE = config["global"]["device"]
97
  SEED = config["global"]["seed"]
98
  dataset_kwargs = config["data_module"]["dataset_kwargs"]
99
  test_proportion = config["data_module"].get("test_proportion", 0.2)
100
- train_proportion = 1. - test_proportion
101
- song_data_path="data/songs_cleaned.csv"
102
  song_audio_path = "data/samples"
103
  pl.seed_everything(SEED, workers=True)
104
 
105
  df = pd.read_csv(song_data_path)
106
- x, y = get_examples(df, song_audio_path,class_list=TARGET_CLASSES, multi_label=True)
107
- train_i, test_i = random_split(np.arange(len(x)), [train_proportion, test_proportion])
108
- train_ds = HuggingFaceWaveformSongDataset(x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000)
109
- test_ds = HuggingFaceWaveformSongDataset(x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000)
110
- train_audio_spectrogram_transformer(TARGET_CLASSES, train_ds, test_ds, device=DEVICE)
111
-
112
-
113
- def train_ast_lightning(config:dict):
 
 
 
 
 
 
 
 
 
 
114
  """
115
  work on integration between waveform dataset and environment. Should work for both HF and PTL.
116
  """
@@ -118,45 +133,50 @@ def train_ast_lightning(config:dict):
118
  DEVICE = config["global"]["device"]
119
  SEED = config["global"]["seed"]
120
  pl.seed_everything(SEED, workers=True)
121
- data = DanceDataModule(target_classes=TARGET_CLASSES, dataset_cls=WaveformSongDataset, **config['data_module'])
 
 
 
 
122
  id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
123
  model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
124
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
125
 
126
  model = AutoModelForAudioClassification.from_pretrained(
127
- model_checkpoint,
128
- num_labels=len(label2id),
129
- label2id=label2id,
130
- id2label=id2label,
131
- ignore_mismatched_sizes=True
132
- ).to(DEVICE)
133
  label_weights = data.get_label_weights().to(DEVICE)
134
- criterion = LabelWeightedBCELoss(label_weights) #nn.CrossEntropyLoss(label_weights)
135
- train_env = WaveformTrainingEnvironment(model, criterion,feature_extractor, config)
 
 
136
  callbacks = [
137
  # cb.LearningRateFinder(update_attr=True),
138
  cb.EarlyStopping("val/loss", patience=5),
139
  cb.StochasticWeightAveraging(1e-2),
140
- cb.RichProgressBar()
141
  ]
142
- trainer = pl.Trainer(
143
- callbacks=callbacks,
144
- **config["trainer"]
145
- )
146
  trainer.fit(train_env, datamodule=data)
147
  trainer.test(train_env, datamodule=data)
148
 
149
 
150
- def train_decision_tree(config:dict):
151
  TARGET_CLASSES = config["global"]["dance_ids"]
152
  DEVICE = config["global"]["device"]
153
  SEED = config["global"]["seed"]
154
- song_data_path=config['data_module']["song_data_path"]
155
- song_audio_path = config['data_module']["song_audio_path"]
156
  pl.seed_everything(SEED, workers=True)
157
 
158
  df = pd.read_csv(song_data_path)
159
- x, y = get_examples(df, song_audio_path,class_list=TARGET_CLASSES, multi_label=True)
 
 
160
  # Convert y back to string classes
161
  y = np.array(TARGET_CLASSES)[y.argmax(-1)]
162
  train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
@@ -166,13 +186,18 @@ def train_decision_tree(config:dict):
166
  model.fit(train_x, train_y)
167
  model.save()
168
 
 
169
  if __name__ == "__main__":
170
- parser = ArgumentParser(description="Trains models on the dance dataset and saves weights.")
171
- parser.add_argument("--config",
172
- help="Path to the yaml file that defines the training configuration.",
173
- default="models/config/train.yaml")
 
 
 
 
174
  args = parser.parse_args()
175
  config = get_config(args.config)
176
  training_id = config["global"]["id"]
177
  train = get_training_fn(training_id)
178
- train(config)
 
7
  import pytorch_lightning as pl
8
  from pytorch_lightning import callbacks as cb
9
  from models.utils import LabelWeightedBCELoss
10
+ from models.audio_spectrogram_transformer import (
11
+ train as train_audio_spectrogram_transformer,
12
+ get_id_label_mapping,
13
+ )
14
  from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
15
  from preprocessing.preprocess import get_examples
16
  from models.residual import ResidualDancer, TrainingEnvironment
17
  from models.decision_tree import DanceTreeClassifier, features_from_path
18
  import yaml
19
+ from preprocessing.dataset import (
20
+ DanceDataModule,
21
+ WaveformSongDataset,
22
+ HuggingFaceWaveformSongDataset,
23
+ )
24
  from torch.utils.data import random_split
25
  import numpy as np
26
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
27
  from argparse import ArgumentParser
28
 
29
 
 
30
  import torch
31
  from torch import nn
32
  from sklearn.utils.class_weight import compute_class_weight
33
 
34
+
35
+ def get_training_fn(id: str) -> Callable:
36
  match id:
37
  case "ast_ptl":
38
  return train_ast_lightning
 
45
  case _:
46
  raise Exception(f"Couldn't find a training function for '{id}'.")
47
 
48
+
49
+ def get_config(filepath: str) -> dict:
50
  with open(filepath, "r") as f:
51
  config = yaml.safe_load(f)
52
  return config
 
56
  df = pd.read_csv("data/songs.csv")
57
  g_config = config["global"]
58
  batch_size = config["data_module"]["batch_size"]
59
+ x, y = get_examples(df, "data/samples", class_list=g_config["dance_ids"])
60
+ dataset = SongDataset(x, y)
61
+ splits = KFold(n_splits=k, shuffle=True, random_state=g_config["seed"])
62
  trainer = pl.Trainer(accelerator=g_config["device"])
63
+ for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
64
  print(f"Fold {fold+1}")
65
  model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
66
+ train_env = TrainingEnvironment(model, nn.BCELoss())
67
  train_sampler = SubsetRandomSampler(train_idx)
68
  test_sampler = SubsetRandomSampler(val_idx)
69
  train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
 
72
  trainer.test(train_env, test_loader)
73
 
74
 
75
+ def train_model(config: dict):
76
  TARGET_CLASSES = config["global"]["dance_ids"]
77
  DEVICE = config["global"]["device"]
78
  SEED = config["global"]["seed"]
79
  pl.seed_everything(SEED, workers=True)
80
+ data = DanceDataModule(target_classes=TARGET_CLASSES, **config["data_module"])
81
+ model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
82
  label_weights = data.get_label_weights().to(DEVICE)
83
+ criterion = LabelWeightedBCELoss(
84
+ label_weights
85
+ ) # nn.CrossEntropyLoss(label_weights)
86
  train_env = TrainingEnvironment(model, criterion, config)
87
  callbacks = [
88
  # cb.LearningRateFinder(update_attr=True),
 
91
  cb.RichProgressBar(),
92
  cb.DeviceStatsMonitor(),
93
  ]
94
+ trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
 
 
 
95
  trainer.fit(train_env, datamodule=data)
96
  trainer.test(train_env, datamodule=data)
97
 
98
 
99
+ def train_ast(config: dict):
 
 
100
  TARGET_CLASSES = config["global"]["dance_ids"]
101
  DEVICE = config["global"]["device"]
102
  SEED = config["global"]["seed"]
103
  dataset_kwargs = config["data_module"]["dataset_kwargs"]
104
  test_proportion = config["data_module"].get("test_proportion", 0.2)
105
+ train_proportion = 1.0 - test_proportion
106
+ song_data_path = "data/songs_cleaned.csv"
107
  song_audio_path = "data/samples"
108
  pl.seed_everything(SEED, workers=True)
109
 
110
  df = pd.read_csv(song_data_path)
111
+ x, y = get_examples(
112
+ df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
113
+ )
114
+ train_i, test_i = random_split(
115
+ np.arange(len(x)), [train_proportion, test_proportion]
116
+ )
117
+ train_ds = HuggingFaceWaveformSongDataset(
118
+ x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000
119
+ )
120
+ test_ds = HuggingFaceWaveformSongDataset(
121
+ x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000
122
+ )
123
+ train_audio_spectrogram_transformer(
124
+ TARGET_CLASSES, train_ds, test_ds, device=DEVICE
125
+ )
126
+
127
+
128
+ def train_ast_lightning(config: dict):
129
  """
130
  work on integration between waveform dataset and environment. Should work for both HF and PTL.
131
  """
 
133
  DEVICE = config["global"]["device"]
134
  SEED = config["global"]["seed"]
135
  pl.seed_everything(SEED, workers=True)
136
+ data = DanceDataModule(
137
+ target_classes=TARGET_CLASSES,
138
+ dataset_cls=WaveformSongDataset,
139
+ **config["data_module"],
140
+ )
141
  id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
142
  model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
143
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
144
 
145
  model = AutoModelForAudioClassification.from_pretrained(
146
+ model_checkpoint,
147
+ num_labels=len(label2id),
148
+ label2id=label2id,
149
+ id2label=id2label,
150
+ ignore_mismatched_sizes=True,
151
+ ).to(DEVICE)
152
  label_weights = data.get_label_weights().to(DEVICE)
153
+ criterion = LabelWeightedBCELoss(
154
+ label_weights
155
+ ) # nn.CrossEntropyLoss(label_weights)
156
+ train_env = WaveformTrainingEnvironment(model, criterion, feature_extractor, config)
157
  callbacks = [
158
  # cb.LearningRateFinder(update_attr=True),
159
  cb.EarlyStopping("val/loss", patience=5),
160
  cb.StochasticWeightAveraging(1e-2),
161
+ cb.RichProgressBar(),
162
  ]
163
+ trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
 
 
 
164
  trainer.fit(train_env, datamodule=data)
165
  trainer.test(train_env, datamodule=data)
166
 
167
 
168
+ def train_decision_tree(config: dict):
169
  TARGET_CLASSES = config["global"]["dance_ids"]
170
  DEVICE = config["global"]["device"]
171
  SEED = config["global"]["seed"]
172
+ song_data_path = config["data_module"]["song_data_path"]
173
+ song_audio_path = config["data_module"]["song_audio_path"]
174
  pl.seed_everything(SEED, workers=True)
175
 
176
  df = pd.read_csv(song_data_path)
177
+ x, y = get_examples(
178
+ df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
179
+ )
180
  # Convert y back to string classes
181
  y = np.array(TARGET_CLASSES)[y.argmax(-1)]
182
  train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
 
186
  model.fit(train_x, train_y)
187
  model.save()
188
 
189
+
190
  if __name__ == "__main__":
191
+ parser = ArgumentParser(
192
+ description="Trains models on the dance dataset and saves weights."
193
+ )
194
+ parser.add_argument(
195
+ "--config",
196
+ help="Path to the yaml file that defines the training configuration.",
197
+ default="models/config/train_local.yaml",
198
+ )
199
  args = parser.parse_args()
200
  config = get_config(args.config)
201
  training_id = config["global"]["id"]
202
  train = get_training_fn(training_id)
203
+ train(config)