Commit 51f4763 (parent: 17a2a7d) by waidhoferj
updated production build to use multiple overlapping samples
Files changed:
- app.py (+23, -5)
- models/config/train_local.yaml (+8, -9)
- models/residual.py (+2, -5)
- preprocessing/dataset.py (+1, -5)
- preprocessing/pipelines.py (+10, -6)
app.py

@@ -7,7 +7,7 @@ from functools import cache
 from pathlib import Path
 from models.residual import ResidualDancer
 from models.training_environment import TrainingEnvironment
-from preprocessing.pipelines import SpectrogramProductionPipeline
+from preprocessing.pipelines import SpectrogramProductionPipeline, WaveformPreprocessing
 import torch
 from torch import nn
 import yaml
@@ -17,6 +17,8 @@ CONFIG_FILE = Path("models/weights/ResidualDancer/multilabel/config.yaml")
 
 DANCE_MAPPING_FILE = Path("data/dance_mapping.csv")
 
+MIN_DURATION = 3.0
+
 
 class DancePredictor:
     def __init__(
@@ -37,6 +39,9 @@ class DancePredictor:
         self.labels = np.array(labels)
         self.device = device
         self.model = self.get_model(weight_path)
+        self.process_waveform = WaveformPreprocessing(
+            resample_frequency * expected_duration
+        )
         self.extractor = SpectrogramProductionPipeline()
 
     def get_model(self, weight_path: str) -> nn.Module:
@@ -87,10 +92,21 @@ class DancePredictor:
         waveform = torchaudio.functional.resample(
             waveform, sample_rate, self.resample_frequency
         )
-
-
+        window_size = self.resample_frequency * self.expected_duration
+        n_preds = int(waveform.shape[1] // (window_size / 2))
+        step_size = int(waveform.shape[1] / n_preds)
+
+        inputs = [
+            waveform[:, i * step_size : i * step_size + window_size]
+            for i in range(n_preds)
+        ]
+        features = [self.extractor(window) for window in inputs]
+        features = torch.stack(features).to(self.device)
         results = self.model(features)
-
+        # Convert to probabilities
+        results = nn.functional.softmax(results, dim=1)
+        # Take average prediction over all of the windows
+        results = results.mean(dim=0)
         results = results.detach().cpu().numpy()
 
         result_mask = results > self.threshold
@@ -116,6 +132,9 @@ def predict(audio: tuple[int, np.ndarray]) -> list[str]:
     if audio is None:
         return "Dance Not Found"
     sample_rate, waveform = audio
+    duration = len(waveform) / sample_rate
+    if duration < MIN_DURATION:
+        return f"Please record at least {MIN_DURATION} seconds of audio"
 
     model = get_model(CONFIG_FILE)
     results = model(waveform, sample_rate)
@@ -133,7 +152,6 @@ def demo():
 
     recording_interface = gr.Interface(
         fn=predict,
-        description="Record at least **6 seconds** of the song.",
         inputs=gr.Audio(source="microphone", label="Song Recording"),
         outputs=gr.Label(label="Dances"),
         examples=example_audio,
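The change to DancePredictor above is the heart of the commit: instead of scoring a single clip, the app now slices the recording into roughly 50%-overlapping windows, classifies each window, and averages the softmax outputs. A minimal self-contained sketch of the same idea follows; the extractor and model arguments stand in for the Space's SpectrogramProductionPipeline and ResidualDancer, and the 16 kHz sample rate and 6-second window are assumed values for illustration, not the app's configured ones.

    import torch

    def predict_over_windows(waveform: torch.Tensor, extractor, model,
                             sample_rate: int = 16_000, duration: int = 6) -> torch.Tensor:
        # Slice the [1, T] waveform into ~50%-overlapping windows of `duration` seconds.
        window_size = sample_rate * duration
        n_preds = max(1, int(waveform.shape[1] // (window_size / 2)))
        step_size = int(waveform.shape[1] / n_preds)
        windows = [
            waveform[:, i * step_size : i * step_size + window_size]
            for i in range(n_preds)
        ]
        # Short trailing windows would need padding, which the real app
        # delegates to WaveformPreprocessing before feature extraction.
        features = torch.stack([extractor(w) for w in windows])
        # Score every window and average the class probabilities.
        probs = torch.nn.functional.softmax(model(features), dim=1)
        return probs.mean(dim=0)  # one probability vector for the whole recording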
models/config/train_local.yaml

@@ -1,12 +1,15 @@
 training_fn: residual.train_residual_dancer
-checkpoint: lightning_logs/version_176/checkpoints/epoch=12-step=40404.ckpt
 device: mps
 seed: 42
 dance_ids: &dance_ids
   - BCH
+  - BOL
   - CHA
-  - JIV
   - ECS
+  - HST
+  - LHP
+  - NC2
+  - JIV
   - QST
   - RMB
   - SFT
@@ -20,8 +23,7 @@ dance_ids: &dance_ids
 data_module:
   batch_size: 128
   num_workers: 10
-
-  test_proportion: 0.001
+  test_proportion: 0.15
 
 datasets:
   preprocessing.dataset.BestBallroomDataset:
@@ -31,7 +33,7 @@ datasets:
 
   preprocessing.dataset.Music4DanceDataset:
     song_data_path: data/songs_cleaned.csv
-    song_audio_path: data/samples
+    song_audio_path: data/samples
     class_list: *dance_ids
     multi_label: True
     min_votes: 1
@@ -56,7 +58,4 @@ trainer:
   # overfit_batches: 1
 
 training_environment:
-  learning_rate: 0.
-  # loggers:
-  #   models.training_environment.SpectrogramLogger:
-  #     frequency: 100
+  learning_rate: 0.00053
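Once loaded, the YAML anchor &dance_ids and the per-section keys above become plain Python structures; a small sketch (using PyYAML, which the app already imports) of what this commit's values look like after loading:

    import yaml

    with open("models/config/train_local.yaml") as f:
        config = yaml.safe_load(f)

    # The *dance_ids alias expands to the same list as &dance_ids,
    # which now also includes BOL, HST, LHP, NC2 and JIV.
    music4dance = config["datasets"]["preprocessing.dataset.Music4DanceDataset"]
    assert music4dance["class_list"] == config["dance_ids"]

    print(config["data_module"]["test_proportion"])         # 0.15 after this commit
    print(config["training_environment"]["learning_rate"])  # 0.00053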
models/residual.py

@@ -119,14 +119,11 @@ def train_residual_dancer(config: dict):
     data = DanceDataModule(dataset, **config["data_module"])
     model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
     label_weights = data.get_label_weights().to(DEVICE)
-    criterion = LabelWeightedBCELoss(
-        label_weights
-    )  # nn.CrossEntropyLoss(label_weights)
+    criterion = LabelWeightedBCELoss(label_weights)
 
     train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
-
-        cb.EarlyStopping("val/loss", patience=1),
+        cb.EarlyStopping("val/loss", patience=2),
         cb.StochasticWeightAveraging(1e-2),
         cb.RichProgressBar(),
     ]
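The callbacks list above follows the standard PyTorch Lightning pattern; a hedged sketch of how such a list is typically handed to a Trainer (the repo's real Trainer arguments come from config["trainer"] and are not shown in this diff):

    import pytorch_lightning as pl
    from pytorch_lightning import callbacks as cb

    callbacks = [
        cb.EarlyStopping("val/loss", patience=2),  # stop after 2 epochs with no val/loss improvement
        cb.StochasticWeightAveraging(1e-2),        # stochastic weight averaging (1e-2 is swa_lrs in recent Lightning versions)
        cb.RichProgressBar(),                      # nicer console progress output
    ]
    trainer = pl.Trainer(callbacks=callbacks)
    # trainer.fit(train_env, datamodule=data) would then run training in this repo.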
preprocessing/dataset.py

@@ -424,11 +424,7 @@ def record_audio_durations(folder: str):
     music_files = iglob(os.path.join(folder, "**", "*.wav"), recursive=True)
     for file in music_files:
         meta = ta.info(file)
-        durations[file] = meta.num_frames / meta.sample_rate
+        durations[os.path.relpath(file, folder)] = meta.num_frames / meta.sample_rate
 
     with open(os.path.join(folder, "audio_durations.json"), "w") as f:
         json.dump(durations, f)
-
-
-class GTZAN:
-    pass
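The one-line change above keys durations by a path relative to the scanned folder rather than the raw glob path, so audio_durations.json stays valid if the dataset directory is moved. A quick illustration with made-up paths:

    import os

    folder = "data/ballroom"
    file = os.path.join(folder, "cha", "song1.wav")
    print(os.path.relpath(file, folder))  # "cha/song1.wav" -- the key written to audio_durations.json
    # Previously the key was the full path returned by iglob, e.g. "data/ballroom/cha/song1.wav".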
preprocessing/pipelines.py

@@ -95,23 +95,27 @@ class WaveformPreprocessing(torch.nn.Module):
         self.expected_sample_length = expected_sample_length
 
     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        c_dim = 1 if len(waveform.shape) == 3 else 0
         # Take out extra channels
-        if waveform.shape[
-            waveform = waveform.mean(
+        if waveform.shape[c_dim] > 1:
+            waveform = waveform.mean(c_dim, keepdim=True)
 
         # ensure it is the correct length
-        waveform = self._rectify_duration(waveform)
+        waveform = self._rectify_duration(waveform, c_dim)
         return waveform
 
-    def _rectify_duration(self, waveform: torch.Tensor):
+    def _rectify_duration(self, waveform: torch.Tensor, channel_dim: int):
         expected_samples = self.expected_sample_length
-        sample_count = waveform.shape[1]
+        sample_count = waveform.shape[channel_dim + 1]
         if expected_samples == sample_count:
             return waveform
         elif expected_samples > sample_count:
             pad_amount = expected_samples - sample_count
             return torch.nn.functional.pad(
-                waveform,
+                waveform,
+                (channel_dim + 1) * [0] + [pad_amount],
+                mode="constant",
+                value=0.0,
            )
         else:
             return waveform[:, :expected_samples]
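The generalized padding above builds the pad specification from channel_dim; for the common mono [1, T] case (channel_dim = 0) it reduces to right-padding the time axis with zeros. A minimal sketch of that case, with an assumed 16 kHz / 6-second target length rather than the repo's configured values:

    import torch
    import torch.nn.functional as F

    expected_samples = 16_000 * 6              # assumed target length (sample rate * duration)
    waveform = torch.randn(1, 70_000)          # mono clip shorter than the target
    pad_amount = expected_samples - waveform.shape[1]

    # [0, pad_amount] pads the last (time) dimension: 0 samples on the left, pad_amount on the right.
    padded = F.pad(waveform, [0, pad_amount], mode="constant", value=0.0)
    print(padded.shape)                        # torch.Size([1, 96000])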