waidhoferj committed on
Commit a8c0792
1 Parent(s): 51f4763

fixed weighting strategy

environment.yml CHANGED
@@ -1,4 +1,4 @@
-name: dancer-net
+name: dance-classifier
 channels:
   - pytorch
   - nvidia
models/audio_spectrogram_transformer.py CHANGED
@@ -20,8 +20,7 @@ from preprocessing.dataset import (
     HuggingFaceDatasetWrapper,
     get_datasets,
 )
-from preprocessing.dataset import get_music4dance_examples
-from .utils import get_id_label_mapping, compute_hf_metrics
+from .utils import LabelWeightedBCELoss, get_id_label_mapping, compute_hf_metrics
 
 import pytorch_lightning as pl
 from pytorch_lightning import callbacks as cb
@@ -54,7 +53,7 @@ class AST(nn.Module):
 class ASTExtractorWrapper:
     def __init__(self, sampling_rate=16000, return_tensors="pt") -> None:
         max_length = 1024
-        self.extractor = ASTFeatureExtractor(do_normalize=False, max_length=max_length)
+        self.extractor = ASTFeatureExtractor(max_length=max_length, do_normalize=True)
         self.sampling_rate = sampling_rate
         self.return_tensors = return_tensors
         self.waveform_pipeline = WaveformTrainingPipeline()  # TODO configure from yaml
@@ -68,8 +67,6 @@ class ASTExtractorWrapper:
         )
 
         x = x["input_values"].squeeze(0).to(device)
-        # normalize
-        x = (x - x.mean()) / x.std()
         return x
 
 
@@ -90,9 +87,7 @@ def train_lightning_ast(config: dict):
     )
     model = AST(TARGET_CLASSES).to(DEVICE)
     label_weights = data.get_label_weights().to(DEVICE)
-    criterion = nn.CrossEntropyLoss(
-        label_weights
-    )  # LabelWeightedBCELoss(label_weights)
+    criterion = LabelWeightedBCELoss(label_weights)
     if "checkpoint" in config:
         train_env = TrainingEnvironment.load_from_checkpoint(
            config["checkpoint"], criterion=criterion, model=model, config=config
@@ -100,8 +95,7 @@ def train_lightning_ast(config: dict):
     else:
         train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
-        # cb.LearningRateFinder(update_attr=True),
-        cb.EarlyStopping("val/loss", patience=5),
+        cb.EarlyStopping("val/loss", patience=2),
         cb.RichProgressBar(),
     ]
     trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
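
The criterion is now built from LabelWeightedBCELoss (imported from .utils) instead of a class-weighted nn.CrossEntropyLoss, which fits the multi-label dance annotations better. That class is defined in models/utils.py and is not part of this diff; the sketch below is a minimal, assumed implementation of a per-class weighted BCE over raw logits, shown only to illustrate how the label_weights tensor might be consumed.

```python
import torch
from torch import nn


class LabelWeightedBCELoss(nn.Module):
    """Hypothetical sketch; the real implementation lives in models/utils.py
    and may differ (e.g. it might expect probabilities rather than logits)."""

    def __init__(self, label_weights: torch.Tensor):
        super().__init__()
        # One inverse-frequency weight per dance class.
        self.register_buffer("label_weights", label_weights)

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Element-wise BCE, then scale each class column by its weight.
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction="none"
        )
        return (bce * self.label_weights).mean()
```
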
models/config/train_local.yaml CHANGED
@@ -4,7 +4,7 @@ seed: 42
 dance_ids: &dance_ids
   - BCH
   - BOL
-  - CHA
+  # - CHA
   - ECS
   - HST
   - LHP
models/training_environment.py CHANGED
@@ -60,9 +60,6 @@ class TrainingEnvironment(pl.LightningModule):
             multi_label=self.has_multi_label_predictions,
         )
         self.log_dict(metrics, prog_bar=True)
-        experiment = self.logger.experiment
-        for logger in self.experiment_loggers:
-            logger.step(experiment, batch_index, features, labels)
         return loss
 
     def validation_step(
@@ -117,8 +114,10 @@ class TrainingEnvironment(pl.LightningModule):
         dance_ids = sorted(self.config["dance_ids"])
         np.fill_diagonal(self.test_cm, 0)
         cm = self.test_cm / self.test_cm.max()
-        ConfusionMatrixDisplay(cm, display_labels=dance_ids).plot()
-        image = plot_to_image(plt.gcf())
+        cm_plot = ConfusionMatrixDisplay(cm, display_labels=dance_ids)
+        fig, ax = plt.subplots(figsize=(12, 12))
+        cm_plot.plot(ax=ax)
+        image = plot_to_image(fig)
         image = torch.tensor(image, dtype=torch.uint8)
         image = image.permute(2, 0, 1)
         self.logger.experiment.add_image("test/confusion_matrix", image, 0)
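
The confusion matrix is now rendered on an explicit 12x12-inch figure instead of whatever plt.gcf() happened to hold, so the logged image no longer depends on global pyplot state. plot_to_image is an existing helper elsewhere in the repo and is not shown in this commit; a common shape for such a helper, producing the (H, W, C) uint8 array that the following permute(2, 0, 1) converts to CHW for add_image, is roughly the PIL-based version below (an assumption, not the repo's code).

```python
import io

import numpy as np
from matplotlib.figure import Figure
from PIL import Image


def plot_to_image(figure: Figure) -> np.ndarray:
    """Render a matplotlib figure to an (H, W, 3) uint8 RGB array."""
    buffer = io.BytesIO()
    figure.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    return np.asarray(Image.open(buffer).convert("RGB"))
```
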
preprocessing/dataset.py CHANGED
@@ -99,10 +99,21 @@ class SongDataset(Dataset):
             total_slices += audio_slices
 
     def get_label_weights(self):
-        n_examples, n_classes = self.dance_labels.shape
-        weights = n_examples / (n_classes * sum(self.dance_labels))
-        weights[np.isinf(weights)] = 0.0
-        return torch.from_numpy(weights)
+        n_examples = len(self)
+        n_classes = self.dance_labels.shape[1]
+        dance_label_counts = self.get_dance_label_counts()
+        weights = n_examples / (n_classes * dance_label_counts)
+        weights[np.isinf(weights) | np.isnan(weights)] = 1.0
+        return torch.from_numpy(weights).type(torch.float32)
+
+    def get_dance_label_counts(self) -> np.ndarray:
+        """
+        Returns the number of examples for each dance label.
+        """
+        examples_per_audio = np.expand_dims(
+            np.array(self.audio_durations) // self.audio_window_duration, axis=-1
+        )
+        return sum(self.dance_labels * examples_per_audio)
 
     def _backtrace_audio_path(self, index: int) -> str:
         return self.audio_paths[self._idx2audio_idx(index)]
@@ -230,6 +241,7 @@ class Music4DanceDataset(Dataset):
         class_list=None,
         multi_label=True,
         min_votes=1,
+        class_count_limit=None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -240,6 +252,7 @@ class Music4DanceDataset(Dataset):
             class_list=class_list,
             multi_label=multi_label,
             min_votes=min_votes,
+            class_count_limit=class_count_limit,
         )
         self.song_dataset = SongDataset(
             song_paths,
@@ -256,7 +269,12 @@ class Music4DanceDataset(Dataset):
 
 
 def get_music4dance_examples(
-    df: pd.DataFrame, audio_dir: str, class_list=None, multi_label=True, min_votes=1
+    df: pd.DataFrame,
+    audio_dir: str,
+    class_list=None,
+    multi_label=True,
+    min_votes=1,
+    class_count_limit=None,
 ) -> tuple[np.ndarray, np.ndarray]:
     sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)].copy(deep=True)
     sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
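
The rewritten get_label_weights counts window-level examples per class: each audio file contributes duration // audio_window_duration slices, and a class's count is the sum of slices across the files labeled with it. Each class then receives the usual inverse-frequency weight n_examples / (n_classes * class_count), with unseen classes defaulting to 1.0 rather than 0.0. A small worked example with made-up counts:

```python
import numpy as np
import torch

# Hypothetical window-level example counts for three dance classes.
dance_label_counts = np.array([100.0, 50.0, 10.0])
n_examples, n_classes = 160, 3

weights = n_examples / (n_classes * dance_label_counts)
weights[np.isinf(weights) | np.isnan(weights)] = 1.0  # classes with zero examples
print(torch.from_numpy(weights).float())
# tensor([0.5333, 1.0667, 5.3333]) -- the rarest class gets the largest weight
```
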
preprocessing/preprocess.py CHANGED
@@ -10,9 +10,7 @@ import torchaudio
 import torch
 from tqdm import tqdm
 
-
-def url_to_filename(url: str) -> str:
-    return f"{url.split('/')[-1]}.wav"
+from preprocessing.utils import url_to_filename
 
 
 def has_valid_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
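
url_to_filename itself is unchanged; it now lives in preprocessing/utils.py and is imported here. It keeps only the last path segment of a sample URL and appends .wav, for example (the URL below is illustrative, not from the dataset):

```python
from preprocessing.utils import url_to_filename

print(url_to_filename("https://example.com/media/blobs/abc123"))  # -> "abc123.wav"
```
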
preprocessing/utils.py ADDED
@@ -0,0 +1,32 @@
+import sounddevice as sd
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def url_to_filename(url: str) -> str:
+    return f"{url.split('/')[-1]}.wav"
+
+
+def play_audio(waveform: np.ndarray, sample_rate: int):
+    """
+    Assumes that waveform is a numpy array normalized between -1 and 1.
+    """
+    if waveform.max() > 1.0 or waveform.min() < -1.0:
+        raise ValueError("waveform must be a numpy array normalized between -1 and 1.")
+    sd.play(waveform, sample_rate)
+    sd.wait()
+
+
+def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
+    """
+    Assumes that the spectrogram is in decibels.
+    """
+    fig, axs = plt.subplots(1, 1)
+    axs.set_title(title or "Spectrogram (db)")
+    axs.set_ylabel(ylabel)
+    axs.set_xlabel("frame")
+    im = axs.imshow(spec, origin="lower", aspect=aspect)
+    if xmax:
+        axs.set_xlim((0, xmax))
+    fig.colorbar(im, ax=axs)
+    return fig
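
The new module collects small debugging helpers: play_audio expects a mono float waveform normalized to [-1, 1], and plot_spectrogram expects a 2-D spectrogram already converted to decibels. A usage sketch, assuming a local clip and torchaudio's standard transforms (the file path and transform settings are illustrative, not part of this commit):

```python
import torchaudio
from preprocessing.utils import play_audio, plot_spectrogram

# Hypothetical clip; torchaudio.load returns (channels, frames) floats in [-1, 1].
waveform, sample_rate = torchaudio.load("data/samples/example.wav")
mono = waveform.mean(dim=0)

play_audio(mono.numpy(), sample_rate)

# Mel spectrogram in dB, shaped (n_mels, frames) as plot_spectrogram expects.
mel = torchaudio.transforms.MelSpectrogram(sample_rate)(mono)
mel_db = torchaudio.transforms.AmplitudeToDB()(mel)
fig = plot_spectrogram(mel_db.numpy(), title="Mel spectrogram (dB)")
fig.savefig("mel_spectrogram.png")
```
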