waidhoferj committed
Commit 42c4703 • 1 Parent(s): 248f682
updated models

Files changed:
- app.py +106 -14
- models/audio_spectrogram_transformer.py +6 -2
- models/config/train_local.yaml +7 -6
- models/residual.py +0 -64
- models/training_environment.py +5 -3
- models/wav2vec2.py +1 -2
app.py
CHANGED
@@ -1,23 +1,113 @@
 from pathlib import Path
 import gradio as gr
 import numpy as np
-from models.residual import DancePredictor
 import os
 from functools import cache
 from pathlib import Path
-
+from models.audio_spectrogram_transformer import AST, ASTExtractorWrapper
+from models.training_environment import TrainingEnvironment
+import torch
+from torch import nn
+import yaml
+import torchaudio
+
+CONFIG_FILE = Path("models/config/train_local.yaml")
+MODEL_CLS = AST
+EXTRACTOR = ASTExtractorWrapper
+
+
+class DancePredictor:
+    def __init__(
+        self,
+        weight_path: str,
+        labels: list[str],
+        expected_duration=6,
+        threshold=0.5,
+        resample_frequency=16000,
+        device="cpu",
+    ):
+        super().__init__()
+
+        self.expected_duration = expected_duration
+        self.threshold = threshold
+        self.resample_frequency = resample_frequency
+
+        self.labels = np.array(labels)
+        self.device = device
+        self.model = self.get_model(weight_path)
+        self.extractor = ASTExtractorWrapper()
+
+    def get_model(self, weight_path: str) -> nn.Module:
+        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
+        model = AST(self.labels).to(self.device)
+        for key in list(weights):
+            weights[
+                key.replace(
+                    "model.",
+                    "",
+                )
+            ] = weights.pop(key)
+        model.load_state_dict(weights, strict=False)
+        return model.to(self.device).eval()
+
+    @classmethod
+    def from_config(cls, config_path: str) -> "DancePredictor":
+        with open(config_path, "r") as f:
+            config = yaml.safe_load(f)
+        weight_path = config["checkpoint"]
+        labels = sorted(config["dance_ids"])
+        expected_duration = 6
+        threshold = 0.5
+        resample_frequency = 16000
+        device = "mps"
+        return DancePredictor(
+            weight_path,
+            labels,
+            expected_duration,
+            threshold,
+            resample_frequency,
+            device,
+        )
+
+    @torch.no_grad()
+    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
+        if waveform.ndim == 1:
+            waveform = np.stack([waveform, waveform]).T
+        waveform = torch.from_numpy(waveform.T)
+        waveform = torchaudio.functional.apply_codec(
+            waveform, sample_rate, "wav", channels_first=True
+        )
+
+        waveform = torchaudio.functional.resample(
+            waveform, sample_rate, self.resample_frequency
+        )
+        waveform = waveform[
+            :, : self.resample_frequency * self.expected_duration
+        ]  # TODO PAD
+        features = self.extractor(waveform)
+        features = features.unsqueeze(0).to(self.device)
+        results = self.model(features)
+        results = nn.functional.softmax(results.squeeze(0), dim=0)
+        results = results.detach().cpu().numpy()
+
+        result_mask = results > self.threshold
+        probs = results[result_mask]
+        dances = self.labels[result_mask]
+
+        return {dance: float(prob) for dance, prob in zip(dances, probs)}
 
 
 @cache
-def get_model(config_path:str) -> DancePredictor:
+def get_model(config_path: str) -> DancePredictor:
     model = DancePredictor.from_config(config_path)
     return model
 
+
 def predict(audio: tuple[int, np.ndarray]) -> list[str]:
     sample_rate, waveform = audio
-
+
     model = get_model(CONFIG_FILE)
-    results = model(waveform,sample_rate)
+    results = model(waveform, sample_rate)
     return results if len(results) else "Dance Not Found"
 
 
@@ -25,34 +115,36 @@ def demo():
     title = "Dance Classifier"
     description = "What should I dance to this song? Pass some audio to the Dance Classifier find out!"
     song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
-    example_audio = [
+    example_audio = [
+        str(song) for song in song_samples.iterdir() if song.name[0] != "."
+    ]
     all_dances = get_model(CONFIG_FILE).labels
-
+
     recording_interface = gr.Interface(
         fn=predict,
         description="Record at least **6 seconds** of the song.",
         inputs=gr.Audio(source="microphone", label="Song Recording"),
         outputs=gr.Label(label="Dances"),
-        examples=example_audio
+        examples=example_audio,
     )
     uploading_interface = gr.Interface(
         fn=predict,
         inputs=gr.Audio(label="Song Audio File"),
         outputs=gr.Label(label="Dances"),
-        examples=example_audio
+        examples=example_audio,
     )
-
+
     with gr.Blocks() as app:
         gr.Markdown(f"# {title}")
         gr.Markdown(description)
-        gr.TabbedInterface(
+        gr.TabbedInterface(
+            [uploading_interface, recording_interface], ["Upload Song", "Record Song"]
+        )
         with gr.Accordion("See all dances", open=False):
             gr.Markdown("\n".join(f"- {dance}" for dance in all_dances))
 
-
-
     return app
 
 
 if __name__ == "__main__":
-    demo().launch()
+    demo().launch()
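Note: the `DancePredictor` previously imported from `models.residual` is now defined inline in `app.py` and wraps the AST model. A minimal sketch of exercising it outside the Gradio UI, assuming the checkpoint path in `models/config/train_local.yaml` exists locally; the sample file name below is a hypothetical placeholder, not part of the commit.

import torchaudio

from app import DancePredictor

# Builds the AST model from the checkpoint referenced in the YAML config.
model = DancePredictor.from_config("models/config/train_local.yaml")

# Hypothetical sample clip; any clip of roughly 6 seconds or more works.
waveform, sample_rate = torchaudio.load("assets/song-samples/example.wav")

# __call__ expects a (samples, channels) numpy array, so transpose torchaudio's
# (channels, samples) layout before passing it in.
scores = model(waveform.numpy().T, sample_rate)

# Only dances whose softmax probability clears the 0.5 threshold are returned.
print(scores or "Dance Not Found")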
models/audio_spectrogram_transformer.py
CHANGED
@@ -88,13 +88,17 @@ def train_lightning_ast(config: dict):
         target_classes=TARGET_CLASSES,
         **config["data_module"],
     )
-
     model = AST(TARGET_CLASSES).to(DEVICE)
     label_weights = data.get_label_weights().to(DEVICE)
     criterion = nn.CrossEntropyLoss(
         label_weights
     )  # LabelWeightedBCELoss(label_weights)
-
+    if "checkpoint" in config:
+        train_env = TrainingEnvironment.load_from_checkpoint(
+            config["checkpoint"], criterion=criterion, model=model, config=config
+        )
+    else:
+        train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
         # cb.LearningRateFinder(update_attr=True),
         cb.EarlyStopping("val/loss", patience=5),
models/config/train_local.yaml
CHANGED
@@ -1,4 +1,5 @@
-training_fn:
+training_fn: wav2vec2.train_huggingface
+checkpoint: lightning_logs/version_172/checkpoints/epoch=3-step=4572.ckpt
 device: mps
 seed: 42
 dance_ids: &dance_ids
@@ -23,10 +24,10 @@ data_module:
   test_proportion: 0.2
 
 datasets:
-  preprocessing.dataset.BestBallroomDataset:
-    audio_dir: data/ballroom-songs
-    class_list: *dance_ids
-    audio_window_jitter: 0.7
+  # preprocessing.dataset.BestBallroomDataset:
+  #   audio_dir: data/ballroom-songs
+  #   class_list: *dance_ids
+  #   audio_window_jitter: 0.7
 
   preprocessing.dataset.Music4DanceDataset:
     song_data_path: data/songs_cleaned.csv
@@ -49,7 +50,7 @@ trainer:
   log_every_n_steps: 15
   accelerator: gpu
   max_epochs: 50
-  min_epochs:
+  min_epochs: 2
   fast_dev_run: False
   # gradient_clip_val: 0.5
   # overfit_batches: 1
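The new `checkpoint` key is read both by `DancePredictor.from_config` in `app.py` and by the resume branch added in `models/audio_spectrogram_transformer.py`. A small sketch of that lookup, assuming the YAML above is on disk:

import yaml

with open("models/config/train_local.yaml") as f:
    config = yaml.safe_load(f)

if "checkpoint" in config:
    # Resume training (or serve inference) from the recorded Lightning checkpoint.
    print("Loading weights from", config["checkpoint"])
else:
    # No checkpoint key: start from randomly initialized weights.
    print("Training from scratch")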
models/residual.py
CHANGED
@@ -107,70 +107,6 @@ class ResBlock(nn.Module):
         return out
 
 
-class DancePredictor:
-    def __init__(
-        self,
-        weight_path: str,
-        labels: list[str],
-        expected_duration=6,
-        threshold=0.5,
-        resample_frequency=16000,
-        device="cpu",
-    ):
-        super().__init__()
-
-        self.expected_duration = expected_duration
-        self.threshold = threshold
-        self.resample_frequency = resample_frequency
-        self.preprocess_waveform = WaveformPreprocessing(
-            resample_frequency * expected_duration
-        )
-        self.audio_to_spectrogram = lambda x: x  # TODO: Fix
-        self.labels = np.array(labels)
-        self.device = device
-        self.model = self.get_model(weight_path)
-
-    def get_model(self, weight_path: str) -> nn.Module:
-        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
-        model = ResidualDancer(n_classes=len(self.labels))
-        for key in list(weights):
-            weights[key.replace("model.", "")] = weights.pop(key)
-        model.load_state_dict(weights)
-        return model.to(self.device).eval()
-
-    @classmethod
-    def from_config(cls, config_path: str) -> "DancePredictor":
-        with open(config_path, "r") as f:
-            config = yaml.safe_load(f)
-        return DancePredictor(**config)
-
-    @torch.no_grad()
-    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
-        if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
-            waveform = waveform.transpose(1, 0)
-        elif len(waveform.shape) == 1:
-            waveform = np.expand_dims(waveform, 0)
-        waveform = torch.from_numpy(waveform.astype("int16"))
-        waveform = torchaudio.functional.apply_codec(
-            waveform, sample_rate, "wav", channels_first=True
-        )
-
-        waveform = torchaudio.functional.resample(
-            waveform, sample_rate, self.resample_frequency
-        )
-        waveform = self.preprocess_waveform(waveform)
-        spectrogram = self.audio_to_spectrogram(waveform)
-        spectrogram = spectrogram.unsqueeze(0).to(self.device)
-
-        results = self.model(spectrogram)
-        results = results.squeeze(0).detach().cpu().numpy()
-        result_mask = results > self.threshold
-        probs = results[result_mask]
-        dances = self.labels[result_mask]
-
-        return {dance: float(prob) for dance, prob in zip(dances, probs)}
-
-
 def train_residual_dancer(config: dict):
     TARGET_CLASSES = config["dance_ids"]
     DEVICE = config["device"]
models/training_environment.py
CHANGED
@@ -17,10 +17,12 @@ class TrainingEnvironment(pl.LightningModule):
         *args,
         **kwargs,
     ):
-        super().__init__(
+        super().__init__()
         self.model = model
         self.criterion = criterion
-        self.learning_rate =
+        self.learning_rate = config["training_environment"].get(
+            "learning_rate", learning_rate
+        )
         self.experiment_loggers = load_loggers(
             config["training_environment"].get("loggers", {})
         )
@@ -64,7 +66,7 @@ class TrainingEnvironment(pl.LightningModule):
             preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
         )
         metrics["val/loss"] = self.criterion(preds, y)
-        self.log_dict(metrics, prog_bar=True)
+        self.log_dict(metrics, prog_bar=True, sync_dist=True)
 
     def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
         x, y = batch
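The `learning_rate` fix above uses a plain `dict.get` fallback, so a value in the YAML's `training_environment` section overrides the constructor default. Illustrative sketch with placeholder values:

# Placeholder config mirroring the training_environment section of the YAML.
config = {"training_environment": {"learning_rate": 3e-4}}
default_learning_rate = 1e-3

# The config value wins when present; otherwise the constructor default is used.
learning_rate = config["training_environment"].get("learning_rate", default_learning_rate)
assert learning_rate == 3e-4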
models/wav2vec2.py
CHANGED
@@ -7,14 +7,13 @@ from transformers import AutoModelForAudioClassification, TrainingArguments, Tra
 
 from preprocessing.dataset import (
     HuggingFaceDatasetWrapper,
-    BestBallroomDataset,
     get_datasets,
 )
 from preprocessing.pipelines import WaveformTrainingPipeline
 
 from .utils import get_id_label_mapping, compute_hf_metrics
 
-MODEL_CHECKPOINT = "
+MODEL_CHECKPOINT = "m3hrdadfi/wav2vec2-base-100k-voxpopuli-gtzan-music"
 
 
 class Wav2VecFeatureExtractor:
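`MODEL_CHECKPOINT` now points at a public wav2vec2 model fine-tuned on GTZAN music genres. A hedged sketch of how such a checkpoint is typically loaded for audio classification with `transformers` (the label list below is an illustrative placeholder; the repo derives its mappings via `get_id_label_mapping`):

from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

checkpoint = "m3hrdadfi/wav2vec2-base-100k-voxpopuli-gtzan-music"

# Placeholder dance labels; the real list comes from the config's dance_ids.
labels = ["Waltz", "Tango", "Salsa"]
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}

feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = AutoModelForAudioClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # the GTZAN genre head is replaced with a new one
)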