Spaces: Runtime error

waidhoferj committed · Commit 4b8361a · 1 Parent(s): dad3c09

lightning modules, spotify scraping and configs
Browse files

- .gitignore +9 -2
- app.py +32 -82
- models/config/dance-predictor.yaml +26 -0
- models/config/train.yaml +23 -0
- models/residual.py +107 -5
- models/utils.py +38 -0
- preprocessing/dataset.py +48 -2
- preprocessing/pipelines.py +63 -0
- preprocessing/preprocess.py +0 -20
- scrapers/music4dance.py +0 -6
- scrapers/spotify.py +87 -0
- scrapers/utils.py +9 -0
.gitignore CHANGED
@@ -1,8 +1,15 @@
 __pycache__
 .DS_Store
 data/samples
-data/samples
+data/spotify-samples
 data/samples-backup.zip
 data/songs.csv
+data/songs_original.csv
 logs
 gradio_cached_examples
+explore.ipynb
+scrapers/auth
+lightning_logs
+data/backup_1.csv
+data/backup.csv
+data/*.zip
app.py CHANGED
@@ -1,105 +1,55 @@
 from pathlib import Path
 import gradio as gr
 import numpy as np
-import torch
-from preprocessing.preprocess import AudioPipeline
-from models.residual import ResidualDancer
+from models.residual import DancePredictor
 import os
-import json
 from functools import cache
-
-
-DEVICE = "cpu"
-
-@cache
-def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
-    model_path = "models/weights/ResidualDancer"
-    weights = os.path.join(model_path, "dancer_net.pt")
-    config_path = os.path.join(model_path, "config.json")
-
-    with open(config_path) as f:
-        config = json.load(f)
-    labels = np.array(sorted(config["classes"]))
-
-    model = ResidualDancer(n_classes=len(labels))
-    model.load_state_dict(torch.load(weights, map_location=DEVICE))
-    model = model.to(device).eval()
-    return model, labels
+from pathlib import Path
+CONFIG_FILE = Path("models/config/dance-predictor.yaml")
 
-@cache
-def get_pipeline(sample_rate:int) -> AudioPipeline:
-    return AudioPipeline(input_freq=sample_rate)
 
 @cache
-def
-
-    return
-
+def get_model(config_path:str) -> DancePredictor:
+    model = DancePredictor.from_config(config_path)
+    return model
 
 def predict(audio: tuple[int, np.ndarray]) -> list[str]:
     sample_rate, waveform = audio
 
-
-
-
-
-
-    audio_pipeline = get_pipeline(sample_rate)
-    model, labels = get_model(DEVICE)
-
-    if sample_len > len(waveform):
-        raise gr.Error("You must record for at least 6 seconds")
-    if len(waveform.shape) > 1 and waveform.shape[1] > 1:
-        waveform = waveform.transpose(1,0)
-        waveform = waveform.mean(axis=0, keepdims=True)
-    else:
-        waveform = np.expand_dims(waveform, 0)
-    waveform = waveform[: ,:sample_len]
-    waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
-    waveform = waveform.astype("float32")
-    waveform = torch.from_numpy(waveform)
-    spectrogram = audio_pipeline(waveform)
-    spectrogram = spectrogram.unsqueeze(0).to(DEVICE)
-
-    with torch.no_grad():
-        results = model(spectrogram)
-    dance_mapping = get_dance_map()
-    results = results.squeeze(0).detach().cpu().numpy()
-    result_mask = results > threshold
-    probs = results[result_mask]
-    dances = labels[result_mask]
-
-    return {dance_mapping[dance_id]:float(prob) for dance_id, prob in zip(dances, probs)} if len(dances) else "Couldn't find a dance."
+    model = get_model(CONFIG_FILE)
+    results = model(waveform,sample_rate)
+    return results if len(results) else "Dance Not Found"
 
 
 def demo():
     title = "Dance Classifier"
-    description = "
+    description = "What should I dance to this song? Pass some audio to the Dance Classifier find out!"
+    song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
+    example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
+    all_dances = get_model(CONFIG_FILE).labels
+
+    recording_interface = gr.Interface(
+        fn=predict,
+        description="Record at least **6 seconds** of the song.",
+        inputs=gr.Audio(source="microphone", label="Song Recording"),
+        outputs=gr.Label(label="Dances"),
+        examples=example_audio
+    )
+    uploading_interface = gr.Interface(
+        fn=predict,
+        inputs=gr.Audio(label="Song Audio File"),
+        outputs=gr.Label(label="Dances"),
+        examples=example_audio
+    )
+
     with gr.Blocks() as app:
         gr.Markdown(f"# {title}")
         gr.Markdown(description)
-
-
-
-
-        with gr.Tab("Upload Song") as t:
-            audio_file = gr.Audio(label="Song Audio File")
-            audio_file_submit = gr.Button("Predict")
-            song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
-            example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
-
-            labels = gr.Label(label="Dances")
+        gr.TabbedInterface([uploading_interface, recording_interface], ["Upload Song", "Record Song"])
+        with gr.Accordion("See all dances", open=False):
+            gr.Markdown("\n".join(f"- {dance}" for dance in all_dances))
 
-
-            gr.Examples(
-                examples=example_audio,
-                inputs=audio_file,
-                outputs=labels,
-                fn=predict,
-            )
-
-            audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
-            mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)
+
 
     return app
 
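A minimal launch sketch (not part of this commit) showing how the demo() factory above would typically be served; the __main__ guard is an assumption, though launching a gr.Blocks app via launch() is standard Gradio.

if __name__ == "__main__":
    # Build the Blocks app defined above and serve it locally (Gradio's default port is 7860).
    demo().launch()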
models/config/dance-predictor.yaml ADDED
@@ -0,0 +1,26 @@
+weight_path: lightning_logs/version_0/checkpoints/epoch=5-step=870.ckpt
+expected_duration: 6
+threshold: 0.5
+resample_frequency: 16000
+device: cpu
+labels:
+- Argentine Tango
+- Balboa
+- Bachata
+- Blues
+- Cha Cha
+- Cumbia
+- Carolina Shag
+- East Coast Swing
+- Hustle
+- Jive
+- Lindy Hop
+- Quickstep
+- Rumba
+- Slow Foxtrot
+- Salsa
+- Samba
+- Slow Waltz
+- Tango (Ballroom)
+- Viennese Waltz
+- West Coast Swing
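For reference, a short sketch of how this config is consumed: DancePredictor.from_config (added in models/residual.py below) reads the YAML with yaml.safe_load and splats the keys into the constructor, so every key above must match a constructor argument.

from models.residual import DancePredictor

# Equivalent to DancePredictor(weight_path=..., labels=[...], expected_duration=6,
#                              threshold=0.5, resample_frequency=16000, device="cpu")
predictor = DancePredictor.from_config("models/config/dance-predictor.yaml")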
models/config/train.yaml ADDED
@@ -0,0 +1,23 @@
+device: mps
+seed: 42
+dance_ids:
+- ATN
+- BBA
+- BCH
+- BLU
+- CHA
+- CMB
+- CSG
+- ECS
+- HST
+- JIV
+- LHP
+- QST
+- RMB
+- SFT
+- SLS
+- SMB
+- SWZ
+- TGO
+- VWZ
+- WCS
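No training script ships in this commit, so the following is only an assumed sketch of how train.yaml might be read; the keys mirror the file above.

import yaml

with open("models/config/train.yaml") as f:
    train_config = yaml.safe_load(f)

seed = train_config["seed"]            # 42
dance_ids = train_config["dance_ids"]  # ["ATN", "BBA", ..., "WCS"]
device = train_config["device"]        # "mps"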
models/residual.py CHANGED
@@ -1,12 +1,18 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+import pytorch_lightning as pl
+import numpy as np
+import torchaudio
+import yaml
+from .utils import calculate_metrics
+from preprocessing.pipelines import AudioPipeline
+
 # Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
 
 class ResidualDancer(nn.Module):
-    def __init__(self,
-                n_channels=128,
-                n_classes=50):
+    def __init__(self, n_channels=128, n_classes=50):
         super().__init__()
 
         # Spectrogram
@@ -50,7 +56,7 @@ class ResidualDancer(nn.Module):
         x = nn.Sigmoid()(x)
 
         return x
-
+
 
 class ResBlock(nn.Module):
     def __init__(self, input_channels, output_channels, shape=3, stride=2):
@@ -78,4 +84,100 @@ class ResBlock(nn.Module):
         x = self.bn_3(self.conv_3(x))
         out = x + out
         out = self.relu(out)
-        return out
+        return out
+
+
+class TrainingEnvironment(pl.LightningModule):
+
+    def __init__(self, model: nn.Module, criterion: nn.Module, learning_rate=1e-4, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model = model
+        self.criterion = criterion
+        self.learning_rate = learning_rate
+
+    def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
+        features, labels = batch
+        outputs = self.model(features)
+        loss = self.criterion(outputs, labels)
+        batch_metrics = calculate_metrics(outputs, labels)
+        self.log_dict(batch_metrics)
+        return loss
+
+    def validation_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
+        x, y = batch
+        preds = self.model(x)
+        metrics = calculate_metrics(preds, y, prefix="val_")
+        metrics["val_loss"] = self.criterion(preds, y)
+        self.log_dict(metrics)
+
+    def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
+        x, y = batch
+        preds = self.model(x)
+        self.log_dict(calculate_metrics(preds, y, prefix="test_"))
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+
+
+class DancePredictor:
+    def __init__(
+        self,
+        weight_path: str,
+        labels: list[str],
+        expected_duration=6,
+        threshold=0.5,
+        resample_frequency=16000,
+        device="cpu"):
+
+        super().__init__()
+
+        self.expected_duration = expected_duration
+        self.threshold = threshold
+        self.resample_frequency = resample_frequency
+        self.audio_pipeline = AudioPipeline(input_freq=self.resample_frequency)
+        self.labels = np.array(labels)
+        self.device = device
+        self.model = self.get_model(weight_path)
+
+    def get_model(self, weight_path: str) -> nn.Module:
+        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
+        model = ResidualDancer(n_classes=len(self.labels))
+        for key in list(weights):
+            weights[key.replace("model.", "")] = weights.pop(key)
+        model.load_state_dict(weights)
+        return model.to(self.device).eval()
+
+    @classmethod
+    def from_config(cls, config_path: str) -> "DancePredictor":
+        with open(config_path, "r") as f:
+            config = yaml.safe_load(f)
+        return DancePredictor(**config)
+
+    @torch.no_grad()
+    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
+        min_sample_len = sample_rate * self.expected_duration
+        if min_sample_len > len(waveform):
+            raise Exception("You must record for at least 6 seconds")
+        if len(waveform.shape) > 1 and waveform.shape[1] > 1:
+            waveform = waveform.transpose(1, 0)
+            waveform = waveform.mean(axis=0, keepdims=True)
+        else:
+            waveform = np.expand_dims(waveform, 0)
+        waveform = waveform[:, :min_sample_len]
+        waveform = torch.from_numpy(waveform.astype("int16"))
+        waveform = torchaudio.functional.apply_codec(waveform, sample_rate, "wav", channels_first=True)
+
+        waveform = torchaudio.functional.resample(waveform, sample_rate, self.resample_frequency)
+        spectrogram = self.audio_pipeline(waveform)
+        spectrogram = spectrogram.unsqueeze(0).to(self.device)
+
+        results = self.model(spectrogram)
+        results = results.squeeze(0).detach().cpu().numpy()
+        result_mask = results > self.threshold
+        probs = results[result_mask]
+        dances = self.labels[result_mask]
+
+        return {dance: float(prob) for dance, prob in zip(dances, probs)}
+
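A hedged sketch of how the new Lightning pieces fit together (the commit includes no train.py, so the Trainer arguments and criterion choice here are assumptions): ResidualDancer stays a plain nn.Module, TrainingEnvironment wraps it for optimization and metric logging, and DancePredictor is the inference-only wrapper used by app.py.

import pytorch_lightning as pl
import torch.nn as nn
from models.residual import ResidualDancer, TrainingEnvironment

model = ResidualDancer(n_classes=20)                       # 20 dances, matching train.yaml
env = TrainingEnvironment(model, criterion=nn.BCELoss())   # criterion is an assumption
trainer = pl.Trainer(max_epochs=5)                         # accelerator/devices left unspecified
# trainer.fit(env, datamodule=...)  # a DanceDataModule is added in preprocessing/dataset.py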
models/utils.py ADDED
@@ -0,0 +1,38 @@
+import torch.nn as nn
+import torch
+import numpy as np
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+
+class LabelWeightedBCELoss(nn.Module):
+    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
+        super().__init__()
+        self.label_weights = label_weights
+
+        match reduction:
+            case "mean":
+                self.reduction = torch.mean
+            case "sum":
+                self.reduction = torch.sum
+
+    def _log(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.clamp_min(torch.log(x), -100)
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        losses = -self.label_weights * (target * self._log(input) + (1 - target) * self._log(1 - input))
+        return self.reduction(losses)
+
+
+def calculate_metrics(pred, target, threshold=0.5, prefix="") -> dict[str, torch.Tensor]:
+    target = target.detach().cpu().numpy()
+    pred = pred.detach().cpu().numpy()
+    pred = np.array(pred > threshold, dtype=float)
+    metrics = {
+        'precision': precision_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
+        'recall': recall_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
+        'f1': f1_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
+        'accuracy': accuracy_score(y_true=target, y_pred=pred),
+    }
+    if prefix != "":
+        metrics = {prefix + k: v for k, v in metrics.items()}
+
+    return {k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()}
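A quick illustration (shapes and values assumed) of the two helpers above: LabelWeightedBCELoss expects per-class weights and sigmoid probabilities, and calculate_metrics thresholds the predictions before computing macro-averaged scores.

import torch
from models.utils import LabelWeightedBCELoss, calculate_metrics

preds = torch.rand(8, 20)                       # sigmoid outputs: 8 windows x 20 dances
targets = torch.randint(0, 2, (8, 20)).float()  # multi-hot dance labels
weights = torch.ones(20)                        # e.g. from DanceDataModule.get_label_weights()

loss = LabelWeightedBCELoss(weights)(preds, targets)
metrics = calculate_metrics(preds, targets, prefix="demo_")  # demo_precision, demo_recall, ...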
preprocessing/dataset.py CHANGED
@@ -1,8 +1,12 @@
 import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, random_split
 import numpy as np
+import pandas as pd
 import torchaudio as ta
-from .
+from .pipelines import AudioPipeline
+import pytorch_lightning as pl
+from .preprocess import get_examples
+
 
 
 class SongDataset(Dataset):
@@ -47,3 +51,45 @@ class SongDataset(Dataset):
     def _label_from_index(self, idx: int) -> torch.Tensor:
         label_idx = idx * self.audio_window_duration // self.audio_duration
         return torch.from_numpy(self.dance_labels[label_idx])
+
+
+class DanceDataModule(pl.LightningDataModule):
+    def __init__(self,
+                 song_data_path="data/songs.csv",
+                 song_audio_path="data/samples",
+                 test_proportion=0.15,
+                 val_proportion=0.1,
+                 target_classes: list[str] = None,
+                 batch_size: int = 64,
+                 num_workers=10
+                 ):
+        super().__init__()
+        self.song_data_path = song_data_path
+        self.song_audio_path = song_audio_path
+        self.val_proportion = val_proportion
+        self.test_proportion = test_proportion
+        self.train_proporition = 1. - test_proportion - val_proportion
+        self.target_classes = target_classes
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        df = pd.read_csv("data/songs.csv")
+        self.x, self.y = get_examples(df, self.song_audio_path, class_list=self.target_classes)
+
+    def setup(self, stage: str):
+        dataset = SongDataset(self.x, self.y)
+        self.train_ds, self.val_ds, self.test_ds = random_split(dataset, [self.train_proporition, self.val_proportion, self.test_proportion])
+
+    def train_dataloader(self):
+        return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+
+    def val_dataloader(self):
+        return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+
+    def test_dataloader(self):
+        return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+
+    def get_label_weights(self):
+        return torch.from_numpy(len(self.y) / (len(self.y[0]) * sum(self.y)))
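A sketch of wiring the new DanceDataModule into training (the paths, class subset, and loss pairing are assumptions; get_label_weights and LabelWeightedBCELoss are both added in this commit).

from preprocessing.dataset import DanceDataModule
from models.utils import LabelWeightedBCELoss

data = DanceDataModule(
    song_data_path="data/songs.csv",
    song_audio_path="data/samples",
    target_classes=["ATN", "BBA", "CHA"],  # hypothetical subset of train.yaml dance_ids
    batch_size=64,
)
criterion = LabelWeightedBCELoss(data.get_label_weights())
# trainer.fit(TrainingEnvironment(model, criterion), datamodule=data)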
preprocessing/pipelines.py ADDED
@@ -0,0 +1,63 @@
+import torch
+from torchaudio import transforms as taT, functional as taF
+import torch.nn as nn
+
+class AudioPipeline(torch.nn.Module):
+    def __init__(
+        self,
+        input_freq=16000,
+        resample_freq=16000,
+    ):
+        super().__init__()
+        self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
+        self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
+        self.to_db = taT.AmplitudeToDB()
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        if waveform.shape[0] > 1:
+            waveform = waveform.mean(0, keepdim=True)
+
+        waveform = (waveform - waveform.mean()) / waveform.abs().max()
+
+        waveform = self.resample(waveform)
+        spectrogram = self.spec(waveform)
+        spectrogram = self.to_db(spectrogram)
+
+        return spectrogram
+
+
+class SpectrogramAugmentationPipeline(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.pipeline = nn.Sequential(
+            taT.FrequencyMasking(80),
+            taT.TimeMasking(80),
+            taT.TimeStretch(80)
+        )
+
+    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
+        return self.pipeline(spectrogram)
+
+
+class WaveformAugmentationPipeline(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        taF.pitch_shift()
+
+
+class AudioTrainingPipeline(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.waveform_aug = WaveformAugmentationPipeline()
+        self.spec_aug = SpectrogramAugmentationPipeline()
+        self.audio_preprocessing = AudioPipeline()
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        x = self.audio_preprocessing(waveform)
+        x = self.spec_aug(x)
+        return x
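A minimal example (synthetic audio) of the AudioPipeline above: it averages multi-channel input to mono, normalizes, resamples, and returns a dB-scaled mel spectrogram.

import torch
from preprocessing.pipelines import AudioPipeline

pipeline = AudioPipeline(input_freq=16000, resample_freq=16000)
waveform = torch.randn(1, 16000 * 6)   # roughly 6 seconds of mono audio at 16 kHz
spectrogram = pipeline(waveform)       # shape: (1, 64 mel bins, time frames)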
preprocessing/preprocess.py CHANGED
@@ -82,23 +82,3 @@ def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[
 
     return audio_paths, list(labels)
 
-class AudioPipeline(torch.nn.Module):
-    def __init__(
-        self,
-        input_freq=16000,
-        resample_freq=16000,
-    ):
-        super().__init__()
-        self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
-        self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
-        self.to_db = taT.AmplitudeToDB()
-
-    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        if waveform.shape[0] > 1:
-            waveform = waveform.mean(0, keepdim=True)
-        waveform = self.resample(waveform)
-        spectrogram = self.spec(waveform)
-        spectrogram = self.to_db(spectrogram)
-
-        return spectrogram
-
scrapers/music4dance.py CHANGED
@@ -75,12 +75,6 @@ def get_songs(soup: bs) -> dict:
     return songs
 
 
-def download_song(url: str, out_dir: str):
-    response = requests.get(url)
-    filename = url.split("/")[-1]
-    out_file = Path(out_dir, f"{filename}.mp3")
-    with open(out_file, "wb") as f:
-        f.write(response.content)
 
 def scrape_dance_info() -> pd.DataFrame:
     js_obj = re.compile(r"{(.|\n)*}")
scrapers/spotify.py ADDED
@@ -0,0 +1,87 @@
+import spotipy
+from spotipy.oauth2 import SpotifyClientCredentials
+import os
+import json
+from pathlib import Path
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from utils import download_song
+import time
+
+def set_env():
+    here = os.path.dirname(__file__)
+    with open(os.path.join(here, "auth", "spotify.json"), "r") as f:
+        config = json.load(f)
+    os.environ["SPOTIPY_CLIENT_ID"] = config["client_id"]
+    os.environ["SPOTIPY_CLIENT_SECRET"] = config["client_secret"]
+    os.environ["SPOTIPY_REDIRECT_URI"] = "https://localhost:8080/callback"
+
+set_env()
+
+
+def get_song_preview_url(song_name: str, spotify: spotipy.Spotify, artist: str = None) -> str | None:
+    info = {
+        "track": song_name
+    }
+    if artist is not None:
+        info["artist"] = artist
+    query = " ".join(f"{k}: {v}" for k, v in info.items())
+    results = spotify.search(query, type="track", limit=1)["tracks"]["items"]
+    valid_results = len(results) > 0 and results[0] is not None and "preview_url" in results[0]
+    if not valid_results:
+        return None
+    song = results[0]
+    return song["preview_url"]
+
+def patch_missing_songs(
+    df: pd.DataFrame,
+) -> pd.DataFrame:
+    spotify = spotipy.Spotify(auth_manager=SpotifyClientCredentials())
+    # find songs with missing previews
+    audio_urls = df["Sample"].replace(".", np.nan)
+    missing_audio = pd.isna(audio_urls)
+    missing_df = df[missing_audio]
+    def patch_preview(row: pd.Series):
+        song: str = row["Title"]
+        artist: str = row["Artist"]
+        preview_url = get_song_preview_url(song, spotify, artist)
+        if preview_url is not None:
+            row["Sample"] = preview_url
+        return row
+    backup_file = open("data/backup_1.csv", "a")
+    rows = []
+    indices = []
+    total_rows = len(missing_df)
+    for i, row in tqdm(missing_df.iloc[11121:].iterrows(), total=total_rows):
+        patched_row = patch_preview(row)
+        backup_file.write(f"{i}, {patched_row['Sample']}\n")
+        rows.append(patch_preview(row))
+        indices.append(i)
+
+    patched_df = pd.DataFrame(rows, index=indices)
+    df.update(patched_df)
+    return df
+
+
+def download_links():
+    start = 3180
+    with open("data/backup_2.csv") as f:
+        links = [x.split(",")[1].strip() for x in f.readlines()]
+    links = links[start:]
+    links = [l for l in links if "https" in l]
+    links = links[2680:]
+    for link in tqdm(links, "Songs Downloaded"):
+        download_song(link, "data/spotify-samples")
+        time.sleep(5e-3)  # hopefully wont be rate limited with delay 🤞
+
+
+if __name__ == "__main__":
+    df = pd.read_csv("data/songs.csv")
+    patched = patch_missing_songs(df)
+    patched.to_csv("data/last_part.csv")
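Running the scraper requires a git-ignored scrapers/auth/spotify.json holding the Spotify API credentials that set_env() exports; a sketch of that file's assumed shape (placeholder values, keys are the ones set_env() reads):

{
    "client_id": "YOUR_SPOTIFY_CLIENT_ID",
    "client_secret": "YOUR_SPOTIFY_CLIENT_SECRET"
}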
scrapers/utils.py ADDED
@@ -0,0 +1,9 @@
+import requests
+from pathlib import Path
+
+def download_song(url: str, out_dir: str, file_type="mp3"):
+    response = requests.get(url)
+    filename = url.split("/")[-1]
+    out_file = Path(out_dir, f"{filename}.{file_type}")
+    with open(out_file, "wb") as f:
+        f.write(response.content)
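A one-line usage sketch of the relocated helper (the URL is a placeholder): it saves the response body as <last-path-segment>.<file_type> inside out_dir.

from utils import download_song  # imported this way in scrapers/spotify.py

download_song("https://p.scdn.co/mp3-preview/EXAMPLE", "data/spotify-samples")
# writes data/spotify-samples/EXAMPLE.mp3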