waidhoferj committed on
Commit
ba35f85
1 Parent(s): 9f53273

updated production weights

app.py CHANGED
@@ -2,18 +2,21 @@ from pathlib import Path
 import gradio as gr
 import numpy as np
 import os
+import pandas as pd
 from functools import cache
 from pathlib import Path
-from models.audio_spectrogram_transformer import AST, ASTExtractorWrapper
+from models.residual import ResidualDancer
 from models.training_environment import TrainingEnvironment
+from preprocessing.pipelines import SpectrogramProductionPipeline
 import torch
 from torch import nn
 import yaml
 import torchaudio
 
-CONFIG_FILE = Path("models/config/train_local.yaml")
-MODEL_CLS = AST
-EXTRACTOR = ASTExtractorWrapper
+CONFIG_FILE = Path("models/config/production.yaml")
+MODEL_CLS = ResidualDancer
+
+DANCE_MAPPING_FILE = Path("data/dance_mapping.csv")
 
 
 class DancePredictor:
@@ -22,7 +25,7 @@ class DancePredictor:
         weight_path: str,
         labels: list[str],
         expected_duration=6,
-        threshold=0.5,
+        threshold=0.1,
         resample_frequency=16000,
         device="cpu",
     ):
@@ -35,11 +38,13 @@ class DancePredictor:
         self.labels = np.array(labels)
         self.device = device
         self.model = self.get_model(weight_path)
-        self.extractor = ASTExtractorWrapper()
+        self.extractor = SpectrogramProductionPipeline()
 
     def get_model(self, weight_path: str) -> nn.Module:
         weights = torch.load(weight_path, map_location=self.device)["state_dict"]
-        model = AST(self.labels).to(self.device)
+        n_classes = len(self.labels)
+        # NOTE: Channels are not taken into account
+        model = ResidualDancer(n_classes=n_classes).to(self.device)
         for key in list(weights):
             weights[
                 key.replace(
@@ -56,10 +61,12 @@ class DancePredictor:
             config = yaml.safe_load(f)
         weight_path = config["checkpoint"]
         labels = sorted(config["dance_ids"])
-        expected_duration = 6
-        threshold = 0.5
-        resample_frequency = 16000
-        device = "mps"
+        dance_mapping = get_dance_mapping(DANCE_MAPPING_FILE)
+        labels = [dance_mapping[label] for label in labels]
+        expected_duration = config.get("expected_duration", 6)
+        threshold = config.get("threshold", 0.1)
+        resample_frequency = config.get("resample_frequency", 16000)
+        device = config.get("device", "cpu")
         return DancePredictor(
             weight_path,
             labels,
@@ -81,9 +88,6 @@ class DancePredictor:
             waveform = torchaudio.functional.resample(
                 waveform, sample_rate, self.resample_frequency
            )
-        waveform = waveform[
-            :, : self.resample_frequency * self.expected_duration
-        ]  # TODO PAD
         features = self.extractor(waveform)
         features = features.unsqueeze(0).to(self.device)
         results = self.model(features)
@@ -103,7 +107,15 @@ def get_model(config_path: str) -> DancePredictor:
     return model
 
 
+@cache
+def get_dance_mapping(mapping_file: str) -> dict[str, str]:
+    mapping_df = pd.read_csv(mapping_file)
+    return {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
+
+
 def predict(audio: tuple[int, np.ndarray]) -> list[str]:
+    if audio is None:
+        return "Dance Not Found"
     sample_rate, waveform = audio
 
     model = get_model(CONFIG_FILE)
@@ -116,7 +128,7 @@ def demo():
     description = "What should I dance to this song? Pass some audio to the Dance Classifier find out!"
     song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
     example_audio = [
-        str(song) for song in song_samples.iterdir() if song.name[0] != "."
+        str(song) for song in song_samples.iterdir() if not song.name.startswith(".")
     ]
     all_dances = get_model(CONFIG_FILE).labels
 
assets/song-samples/besame_mucho.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14ccffab50d9119ec5250fc84e09542dbbf350450102c108ab61846a3c3031c8
+size 5290062
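
To sanity-check the new weights, the sample added above can be run through the predict() entry point in app.py. This is a minimal local sketch, not part of the commit; it assumes predict() accepts the same (sample_rate, waveform) pair that the gr.Audio input produces and tolerates a float waveform in (channels, samples) layout.

import torchaudio

from app import predict

# Load the sample added in this commit; torchaudio.load returns (waveform, sample_rate).
waveform, sample_rate = torchaudio.load("assets/song-samples/besame_mucho.wav")

# predict() takes the Gradio-style (sample_rate, np.ndarray) tuple and returns the
# dances scored above the 0.1 default threshold (or "Dance Not Found" for empty input).
print(predict((sample_rate, waveform.numpy())))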
models/config/production.yaml ADDED
@@ -0,0 +1,20 @@
+checkpoint: models/weights/ResidualDancer/weights.ckpt
+device: cpu
+seed: 42
+dance_ids: &dance_ids
+  - BCH
+  - CHA
+  - JIV
+  - ECS
+  - QST
+  - RMB
+  - SFT
+  - SLS
+  - SMB
+  - SWZ
+  - TGO
+  - VWZ
+  - WCS
+
+model:
+  n_channels: 128
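
For reference, the new DancePredictor.from_config path in app.py consumes this file roughly as follows. The YAML keys and the id/name CSV columns mirror the diff above; the snippet itself is an illustrative restatement rather than code taken from the repo.

import pandas as pd
import yaml

with open("models/config/production.yaml") as f:
    config = yaml.safe_load(f)

# Checkpoint path and class list come straight from the config.
weight_path = config["checkpoint"]       # models/weights/ResidualDancer/weights.ckpt
dance_ids = sorted(config["dance_ids"])  # BCH, CHA, ECS, ...

# Dance ids are translated to display names via data/dance_mapping.csv,
# mirroring get_dance_mapping() in app.py.
mapping_df = pd.read_csv("data/dance_mapping.csv")
id_to_name = {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
labels = [id_to_name[dance_id] for dance_id in dance_ids]

# Keys absent from the YAML fall back to the defaults used in from_config.
threshold = config.get("threshold", 0.1)
device = config.get("device", "cpu")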
models/config/train_local.yaml CHANGED
@@ -1,5 +1,5 @@
-training_fn: wav2vec2.train_huggingface
-checkpoint: lightning_logs/version_172/checkpoints/epoch=3-step=4572.ckpt
+training_fn: residual.train_residual_dancer
+checkpoint: lightning_logs/version_176/checkpoints/epoch=12-step=40404.ckpt
 device: mps
 seed: 42
 dance_ids: &dance_ids
@@ -24,10 +24,10 @@ data_module:
   test_proportion: 0.2
 
   datasets:
-    # preprocessing.dataset.BestBallroomDataset:
-    #   audio_dir: data/ballroom-songs
-    #   class_list: *dance_ids
-    #   audio_window_jitter: 0.7
+    preprocessing.dataset.BestBallroomDataset:
+      audio_dir: data/ballroom-songs
+      class_list: *dance_ids
+      audio_window_jitter: 0.7
 
     preprocessing.dataset.Music4DanceDataset:
       song_data_path: data/songs_cleaned.csv
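
The datasets entries are dotted class paths with constructor kwargs; models/residual.py passes this mapping to get_datasets(), whose internals are not shown in this diff. A generic, purely illustrative way to instantiate such a mapping (the real get_datasets() also wires in the feature extractor, which is omitted here):

import importlib


def build_datasets(dataset_config: dict) -> list:
    # Instantiate each "package.module.ClassName" key with its kwargs (illustrative only).
    datasets = []
    for class_path, kwargs in dataset_config.items():
        module_name, class_name = class_path.rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), class_name)
        datasets.append(cls(**(kwargs or {})))
    return datasets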
models/residual.py CHANGED
@@ -110,7 +110,7 @@ def train_residual_dancer(config: dict):
     TARGET_CLASSES = config["dance_ids"]
     DEVICE = config["device"]
     SEED = config["seed"]
-    torch.set_float32_matmul_precision('medium')
+    torch.set_float32_matmul_precision("medium")
     pl.seed_everything(SEED, workers=True)
     feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
     dataset = get_datasets(config["datasets"], feature_extractor)
@@ -123,7 +123,7 @@ def train_residual_dancer(config: dict):
     train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
         # cb.LearningRateFinder(update_attr=True),
-        cb.EarlyStopping("val/loss", patience=5),
+        cb.EarlyStopping("val/loss", patience=1),
         cb.StochasticWeightAveraging(1e-2),
         cb.RichProgressBar(),
         cb.DeviceStatsMonitor(),
models/weights/ResidualDancer/weights.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e107090ff62ac0b79f4f40271e8b1dd6c3d10d8146264ec49df3c8febe99aa23
3
- size 193651217
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90a58841ce4f40f2981227b63dd848e474e8868795a57da84053e3281c4889c7
3
+ size 193643085
preprocessing/dataset.py CHANGED
@@ -78,8 +78,8 @@ class SongDataset(Dataset):
             return waveform, dance_labels
         else:
             # WARNING: Could cause train/test split leak
-            return self[idx-1]
-
+            print("Invalid output, trying next index...")
+            return self[idx - 1]
 
     def _idx2audio_idx(self, idx: int) -> int:
         return self._get_audio_loc_from_idx(idx)[0]
@@ -424,3 +424,7 @@ def record_audio_durations(folder: str):
 
     with open(os.path.join(folder, "audio_durations.json"), "w") as f:
         json.dump(durations, f)
+
+
+class GTZAN:
+    pass
preprocessing/pipelines.py CHANGED
@@ -74,6 +74,21 @@ class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
         return spec
 
 
+class SpectrogramProductionPipeline(torch.nn.Module):
+    def __init__(self, sample_rate=16000, expected_duration=6, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.preprocess_waveform = WaveformPreprocessing(
+            sample_rate * expected_duration
+        )
+        self.audio_to_spectrogram = AudioToSpectrogram(
+            sample_rate=sample_rate,
+        )
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        waveform = self.preprocess_waveform(waveform)
+        return self.audio_to_spectrogram(waveform)
+
+
 class WaveformPreprocessing(torch.nn.Module):
     def __init__(self, expected_sample_length: int):
         super().__init__()
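
The new SpectrogramProductionPipeline simply composes the existing WaveformPreprocessing and AudioToSpectrogram modules for inference. A rough usage sketch; the input shape is an assumption, and the final unsqueeze mirrors how DancePredictor batches a single example in app.py.

import torch

from preprocessing.pipelines import SpectrogramProductionPipeline

pipeline = SpectrogramProductionPipeline(sample_rate=16000, expected_duration=6)

# A mono clip longer than the expected window; WaveformPreprocessing trims/pads
# it to sample_rate * expected_duration samples before the spectrogram transform.
waveform = torch.randn(1, 16000 * 10)
features = pipeline(waveform)

# DancePredictor adds a batch dimension before calling ResidualDancer.
batch = features.unsqueeze(0)
print(batch.shape)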