waidhoferj committed
Commit 4b8361a (1 parent: dad3c09)

lightning modules, spotify scraping and configs

.gitignore CHANGED
@@ -1,8 +1,15 @@
 __pycache__
 .DS_Store
 data/samples
-data/samples-backup.zip
+data/spotify-samples
 data/samples-backup.zip
 data/songs.csv
+data/songs_original.csv
 logs
-gradio_cached_examples
+gradio_cached_examples
+explore.ipynb
+scrapers/auth
+lightning_logs
+data/backup_1.csv
+data/backup.csv
+data/*.zip
app.py CHANGED
@@ -1,105 +1,55 @@
 from pathlib import Path
 import gradio as gr
 import numpy as np
-import torch
-from preprocessing.preprocess import AudioPipeline
-from models.residual import ResidualDancer
+from models.residual import DancePredictor
 import os
-import json
 from functools import cache
-import pandas as pd
-
-DEVICE = "cpu"
-
-@cache
-def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
-    model_path = "models/weights/ResidualDancer"
-    weights = os.path.join(model_path, "dancer_net.pt")
-    config_path = os.path.join(model_path, "config.json")
-
-    with open(config_path) as f:
-        config = json.load(f)
-    labels = np.array(sorted(config["classes"]))
-
-    model = ResidualDancer(n_classes=len(labels))
-    model.load_state_dict(torch.load(weights, map_location=DEVICE))
-    model = model.to(device).eval()
-    return model, labels
-
-@cache
-def get_pipeline(sample_rate:int) -> AudioPipeline:
-    return AudioPipeline(input_freq=sample_rate)
+from pathlib import Path
+CONFIG_FILE = Path("models/config/dance-predictor.yaml")
 
 
 @cache
-def get_dance_map() -> dict:
-    df = pd.read_csv("data/dance_mapping.csv")
-    return df.set_index("id").to_dict()["name"]
-
+def get_model(config_path:str) -> DancePredictor:
+    model = DancePredictor.from_config(config_path)
+    return model
 
 def predict(audio: tuple[int, np.ndarray]) -> list[str]:
     sample_rate, waveform = audio
 
-    expected_duration = 6
-    threshold = 0.5
-    sample_len = sample_rate * expected_duration
-
-
-    audio_pipeline = get_pipeline(sample_rate)
-    model, labels = get_model(DEVICE)
-
-    if sample_len > len(waveform):
-        raise gr.Error("You must record for at least 6 seconds")
-    if len(waveform.shape) > 1 and waveform.shape[1] > 1:
-        waveform = waveform.transpose(1,0)
-        waveform = waveform.mean(axis=0, keepdims=True)
-    else:
-        waveform = np.expand_dims(waveform, 0)
-    waveform = waveform[: ,:sample_len]
-    waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
-    waveform = waveform.astype("float32")
-    waveform = torch.from_numpy(waveform)
-    spectrogram = audio_pipeline(waveform)
-    spectrogram = spectrogram.unsqueeze(0).to(DEVICE)
-
-    with torch.no_grad():
-        results = model(spectrogram)
-    dance_mapping = get_dance_map()
-    results = results.squeeze(0).detach().cpu().numpy()
-    result_mask = results > threshold
-    probs = results[result_mask]
-    dances = labels[result_mask]
-
-    return {dance_mapping[dance_id]:float(prob) for dance_id, prob in zip(dances, probs)} if len(dances) else "Couldn't find a dance."
+    model = get_model(CONFIG_FILE)
+    results = model(waveform,sample_rate)
+    return results if len(results) else "Dance Not Found"
 
 
 def demo():
     title = "Dance Classifier"
-    description = "Record 6 seconds of a song and find out what dance fits the music."
+    description = "What should I dance to this song? Pass some audio to the Dance Classifier find out!"
+    song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
+    example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
+    all_dances = get_model(CONFIG_FILE).labels
+
+    recording_interface = gr.Interface(
+        fn=predict,
+        description="Record at least **6 seconds** of the song.",
+        inputs=gr.Audio(source="microphone", label="Song Recording"),
+        outputs=gr.Label(label="Dances"),
+        examples=example_audio
+    )
+    uploading_interface = gr.Interface(
+        fn=predict,
+        inputs=gr.Audio(label="Song Audio File"),
+        outputs=gr.Label(label="Dances"),
+        examples=example_audio
+    )
+
     with gr.Blocks() as app:
         gr.Markdown(f"# {title}")
         gr.Markdown(description)
-        with gr.Tab("Record Song"):
-            mic_audio = gr.Audio(source="microphone", label="Song Recording")
-            mic_submit = gr.Button("Predict")
-
-        with gr.Tab("Upload Song") as t:
-            audio_file = gr.Audio(label="Song Audio File")
-            audio_file_submit = gr.Button("Predict")
-            song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
-            example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
-
-        labels = gr.Label(label="Dances")
+        gr.TabbedInterface([uploading_interface, recording_interface], ["Upload Song", "Record Song"])
+        with gr.Accordion("See all dances", open=False):
+            gr.Markdown("\n".join(f"- {dance}" for dance in all_dances))
 
-        gr.Markdown("## Examples")
-        gr.Examples(
-            examples=example_audio,
-            inputs=audio_file,
-            outputs=labels,
-            fn=predict,
-        )
-
-        audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
-        mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)
+
 
     return app
 
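Note that `demo()` only builds and returns the `gr.Blocks` app; nothing in this file launches it, so the server presumably starts from a separate entry point. A minimal sketch of such an entry point, assuming the standard Gradio `launch()` call (not part of this commit):

```python
# Hypothetical entry point: build the rebuilt demo and serve it with Gradio.
from app import demo

if __name__ == "__main__":
    app = demo()   # gr.Blocks holding the TabbedInterface and dance accordion
    app.launch()   # gr.Blocks.launch() starts a local Gradio server
```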
 
models/config/dance-predictor.yaml ADDED
@@ -0,0 +1,26 @@
+weight_path: lightning_logs/version_0/checkpoints/epoch=5-step=870.ckpt
+expected_duration: 6
+threshold: 0.5
+resample_frequency: 16000
+device: cpu
+labels:
+  - Argentine Tango
+  - Balboa
+  - Bachata
+  - Blues
+  - Cha Cha
+  - Cumbia
+  - Carolina Shag
+  - East Coast Swing
+  - Hustle
+  - Jive
+  - Lindy Hop
+  - Quickstep
+  - Rumba
+  - Slow Foxtrot
+  - Salsa
+  - Samba
+  - Slow Waltz
+  - Tango (Ballroom)
+  - Viennese Waltz
+  - West Coast Swing
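These keys mirror the `DancePredictor` constructor added in `models/residual.py` below: `from_config` loads the YAML and splats it into `__init__`. A small sketch of the mapping:

```python
# How the config file feeds DancePredictor (mirrors DancePredictor.from_config below).
import yaml

with open("models/config/dance-predictor.yaml") as f:
    config = yaml.safe_load(f)

# config holds weight_path, expected_duration, threshold, resample_frequency,
# device, and the 20 human-readable labels; from_config then does the equivalent of:
# predictor = DancePredictor(**config)
```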
models/config/train.yaml ADDED
@@ -0,0 +1,23 @@
+device: mps
+seed: 42
+dance_ids:
+  - ATN
+  - BBA
+  - BCH
+  - BLU
+  - CHA
+  - CMB
+  - CSG
+  - ECS
+  - HST
+  - JIV
+  - LHP
+  - QST
+  - RMB
+  - SFT
+  - SLS
+  - SMB
+  - SWZ
+  - TGO
+  - VWZ
+  - WCS
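The training script that consumes `train.yaml` is not part of this commit, so the wiring below is an assumption: the `dance_ids` codes presumably select the target classes for `DanceDataModule`, and `seed`/`device` feed the Lightning setup.

```python
# Hypothetical use of train.yaml (the actual training script is not in this commit).
import yaml
import pytorch_lightning as pl
from preprocessing.dataset import DanceDataModule

with open("models/config/train.yaml") as f:
    config = yaml.safe_load(f)

pl.seed_everything(config["seed"])
# Assumption: the dance ID codes pick the target classes for the data module.
data = DanceDataModule(target_classes=config["dance_ids"])
```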
models/residual.py CHANGED
@@ -1,12 +1,18 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+import pytorch_lightning as pl
+import numpy as np
+import torchaudio
+import yaml
+from .utils import calculate_metrics
+from preprocessing.pipelines import AudioPipeline
+
 # Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
 
 class ResidualDancer(nn.Module):
-    def __init__(self,
-                 n_channels=128,
-                 n_classes=50):
+    def __init__(self,n_channels=128, n_classes=50):
         super().__init__()
 
         # Spectrogram
@@ -50,7 +56,7 @@ class ResidualDancer(nn.Module):
         x = nn.Sigmoid()(x)
 
         return x
-
+
 
 class ResBlock(nn.Module):
     def __init__(self, input_channels, output_channels, shape=3, stride=2):
@@ -78,4 +84,100 @@ class ResBlock(nn.Module):
         x = self.bn_3(self.conv_3(x))
         out = x + out
         out = self.relu(out)
-        return out
+        return out
+
+class TrainingEnvironment(pl.LightningModule):
+
+    def __init__(self, model: nn.Module, criterion: nn.Module, learning_rate=1e-4, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model = model
+        self.criterion = criterion
+        self.learning_rate = learning_rate
+
+    def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
+        features, labels = batch
+        outputs = self.model(features)
+        loss = self.criterion(outputs, labels)
+        batch_metrics = calculate_metrics(outputs, labels)
+        self.log_dict(batch_metrics)
+        return loss
+
+    def validation_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
+        x, y = batch
+        preds = self.model(x)
+        metrics = calculate_metrics(preds, y, prefix="val_")
+        metrics["val_loss"] = self.criterion(preds, y)
+        self.log_dict(metrics)
+
+    def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
+        x, y = batch
+        preds = self.model(x)
+        self.log_dict(calculate_metrics(preds, y, prefix="test_"))
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+
+
+class DancePredictor:
+    def __init__(
+        self,
+        weight_path: str,
+        labels: list[str],
+        expected_duration=6,
+        threshold=0.5,
+        resample_frequency=16000,
+        device="cpu"):
+
+        super().__init__()
+
+        self.expected_duration = expected_duration
+        self.threshold = threshold
+        self.resample_frequency = resample_frequency
+        self.audio_pipeline = AudioPipeline(input_freq=self.resample_frequency)
+        self.labels = np.array(labels)
+        self.device = device
+        self.model = self.get_model(weight_path)
+
+    def get_model(self, weight_path: str) -> nn.Module:
+        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
+        model = ResidualDancer(n_classes=len(self.labels))
+        for key in list(weights):
+            weights[key.replace("model.", "")] = weights.pop(key)
+        model.load_state_dict(weights)
+        return model.to(self.device).eval()
+
+    @classmethod
+    def from_config(cls, config_path: str) -> "DancePredictor":
+        with open(config_path, "r") as f:
+            config = yaml.safe_load(f)
+        return DancePredictor(**config)
+
+    @torch.no_grad()
+    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
+        min_sample_len = sample_rate * self.expected_duration
+        if min_sample_len > len(waveform):
+            raise Exception("You must record for at least 6 seconds")
+        if len(waveform.shape) > 1 and waveform.shape[1] > 1:
+            waveform = waveform.transpose(1, 0)
+            waveform = waveform.mean(axis=0, keepdims=True)
+        else:
+            waveform = np.expand_dims(waveform, 0)
+        waveform = waveform[:, :min_sample_len]
+        waveform = torch.from_numpy(waveform.astype("int16"))
+        waveform = torchaudio.functional.apply_codec(waveform, sample_rate, "wav", channels_first=True)
+
+        waveform = torchaudio.functional.resample(waveform, sample_rate, self.resample_frequency)
+        spectrogram = self.audio_pipeline(waveform)
+        spectrogram = spectrogram.unsqueeze(0).to(self.device)
+
+        results = self.model(spectrogram)
+        results = results.squeeze(0).detach().cpu().numpy()
+        result_mask = results > self.threshold
+        probs = results[result_mask]
+        dances = self.labels[result_mask]
+
+        return {dance: float(prob) for dance, prob in zip(dances, probs)}
+
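For inference, `DancePredictor.__call__` takes Gradio-style integer PCM samples shaped `(samples,)` or `(samples, channels)` plus the sample rate, and returns a `{dance: probability}` dict for scores above the threshold. A usage sketch with a placeholder audio file:

```python
# Inference sketch for the new DancePredictor (the .wav path is a placeholder).
import torchaudio
from models.residual import DancePredictor

predictor = DancePredictor.from_config("models/config/dance-predictor.yaml")

waveform, sample_rate = torchaudio.load("assets/song-samples/some_song.wav")
# Convert to int16 PCM with samples first, matching what gr.Audio hands to predict().
pcm = (waveform.T.squeeze().numpy() * 32767).astype("int16")
scores = predictor(pcm, sample_rate)   # e.g. {"Salsa": 0.93, "Cha Cha": 0.61}
```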
models/utils.py ADDED
@@ -0,0 +1,38 @@
+import torch.nn as nn
+import torch
+import numpy as np
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+
+class LabelWeightedBCELoss(nn.Module):
+    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
+        super().__init__()
+        self.label_weights = label_weights
+
+        match reduction:
+            case "mean":
+                self.reduction = torch.mean
+            case "sum":
+                self.reduction = torch.sum
+
+    def _log(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.clamp_min(torch.log(x), -100)
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        losses = -self.label_weights * (target * self._log(input) + (1 - target) * self._log(1 - input))
+        return self.reduction(losses)
+
+
+def calculate_metrics(pred, target, threshold=0.5, prefix="") -> dict[str, torch.Tensor]:
+    target = target.detach().cpu().numpy()
+    pred = pred.detach().cpu().numpy()
+    pred = np.array(pred > threshold, dtype=float)
+    metrics = {
+        'precision': precision_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
+        'recall': recall_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
+        'f1': f1_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
+        'accuracy': accuracy_score(y_true=target, y_pred=pred),
+    }
+    if prefix != "":
+        metrics = {prefix + k: v for k, v in metrics.items()}
+
+    return {k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()}
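A quick toy check of the two helpers: the loss broadcasts per-label weights over a batch of sigmoid outputs, and `calculate_metrics` thresholds predictions before computing macro-averaged scores.

```python
# Toy example: batch of 2 samples, 3 labels.
import torch
from models.utils import LabelWeightedBCELoss, calculate_metrics

preds = torch.tensor([[0.9, 0.2, 0.7],
                      [0.1, 0.8, 0.4]])
targets = torch.tensor([[1., 0., 1.],
                        [0., 1., 0.]])

criterion = LabelWeightedBCELoss(label_weights=torch.tensor([1.0, 2.0, 1.0]))
loss = criterion(preds, targets)   # scalar mean of the weighted per-label BCE terms
metrics = calculate_metrics(preds, targets, prefix="demo_")
# -> demo_precision, demo_recall, demo_f1, demo_accuracy as float32 tensors
```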
preprocessing/dataset.py CHANGED
@@ -1,8 +1,12 @@
 import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, random_split
 import numpy as np
+import pandas as pd
 import torchaudio as ta
-from .preprocess import AudioPipeline
+from .pipelines import AudioPipeline
+import pytorch_lightning as pl
+from .preprocess import get_examples
+
 
 
 class SongDataset(Dataset):
@@ -47,3 +51,45 @@ class SongDataset(Dataset):
     def _label_from_index(self, idx: int) -> torch.Tensor:
         label_idx = idx * self.audio_window_duration // self.audio_duration
         return torch.from_numpy(self.dance_labels[label_idx])
+
+
+class DanceDataModule(pl.LightningDataModule):
+    def __init__(self,
+                 song_data_path="data/songs.csv",
+                 song_audio_path="data/samples",
+                 test_proportion=0.15,
+                 val_proportion=0.1,
+                 target_classes: list[str] = None,
+                 batch_size: int = 64,
+                 num_workers=10
+                 ):
+        super().__init__()
+        self.song_data_path = song_data_path
+        self.song_audio_path = song_audio_path
+        self.val_proportion = val_proportion
+        self.test_proportion = test_proportion
+        self.train_proporition = 1. - test_proportion - val_proportion
+        self.target_classes = target_classes
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        df = pd.read_csv("data/songs.csv")
+        self.x, self.y = get_examples(df, self.song_audio_path, class_list=self.target_classes)
+
+    def setup(self, stage: str):
+        dataset = SongDataset(self.x, self.y)
+        self.train_ds, self.val_ds, self.test_ds = random_split(dataset, [self.train_proporition, self.val_proportion, self.test_proportion])
+
+    def train_dataloader(self):
+        return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+
+    def val_dataloader(self):
+        return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+
+    def test_dataloader(self):
+        return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+
+    def get_label_weights(self):
+        return torch.from_numpy(len(self.y) / (len(self.y[0]) * sum(self.y)))
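The commit does not include the script that drives training, but the new pieces plug together roughly as follows; everything beyond the imported names is an assumption.

```python
# Hypothetical training run wiring the new Lightning pieces together.
import pytorch_lightning as pl
from preprocessing.dataset import DanceDataModule
from models.residual import ResidualDancer, TrainingEnvironment
from models.utils import LabelWeightedBCELoss

data = DanceDataModule(batch_size=64)
weights = data.get_label_weights()                 # inverse-frequency label weights
model = ResidualDancer(n_classes=len(data.y[0]))
env = TrainingEnvironment(model, LabelWeightedBCELoss(weights))

trainer = pl.Trainer(max_epochs=6)
trainer.fit(env, datamodule=data)                  # checkpoints land in lightning_logs/
```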
preprocessing/pipelines.py ADDED
@@ -0,0 +1,63 @@
+import torch
+from torchaudio import transforms as taT, functional as taF
+import torch.nn as nn
+
+class AudioPipeline(torch.nn.Module):
+    def __init__(
+        self,
+        input_freq=16000,
+        resample_freq=16000,
+    ):
+        super().__init__()
+        self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
+        self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
+        self.to_db = taT.AmplitudeToDB()
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        if waveform.shape[0] > 1:
+            waveform = waveform.mean(0, keepdim=True)
+
+        waveform = (waveform - waveform.mean()) / waveform.abs().max()
+
+        waveform = self.resample(waveform)
+        spectrogram = self.spec(waveform)
+        spectrogram = self.to_db(spectrogram)
+
+        return spectrogram
+
+
+class SpectrogramAugmentationPipeline(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.pipeline = nn.Sequential(
+            taT.FrequencyMasking(80),
+            taT.TimeMasking(80),
+            taT.TimeStretch(80)
+        )
+
+    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
+        return self.pipeline(spectrogram)
+
+
+class WaveformAugmentationPipeline(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        taF.pitch_shift()
+
+
+class AudioTrainingPipeline(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.waveform_aug = WaveformAugmentationPipeline()
+        self.spec_aug = SpectrogramAugmentationPipeline()
+        self.audio_preprocessing = AudioPipeline()
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        x = self.audio_preprocessing(waveform)
+        x = self.spec_aug(x)
+        return x
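`AudioPipeline` here is the same module removed from `preprocessing/preprocess.py` (see below), now with mean/peak normalization applied before resampling. For a 6-second mono clip at 16 kHz it yields a 64-band mel spectrogram in dB; a rough shape check:

```python
# Rough shape check for the relocated AudioPipeline.
import torch
from preprocessing.pipelines import AudioPipeline

pipeline = AudioPipeline(input_freq=16000)
waveform = torch.randn(1, 6 * 16000)   # 6 seconds of mono audio at 16 kHz
spec = pipeline(waveform)
print(spec.shape)                       # approximately (1, 64, 188) with the default hop of n_fft // 2
```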
preprocessing/preprocess.py CHANGED
@@ -82,23 +82,3 @@ def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[
 
     return audio_paths, list(labels)
 
-class AudioPipeline(torch.nn.Module):
-    def __init__(
-        self,
-        input_freq=16000,
-        resample_freq=16000,
-    ):
-        super().__init__()
-        self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
-        self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
-        self.to_db = taT.AmplitudeToDB()
-
-    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        if waveform.shape[0] > 1:
-            waveform = waveform.mean(0, keepdim=True)
-        waveform = self.resample(waveform)
-        spectrogram = self.spec(waveform)
-        spectrogram = self.to_db(spectrogram)
-
-        return spectrogram
-
scrapers/music4dance.py CHANGED
@@ -75,12 +75,6 @@ def get_songs(soup: bs) -> dict:
     return songs
 
 
-def download_song(url: str, out_dir: str):
-    response = requests.get(url)
-    filename = url.split("/")[-1]
-    out_file = Path(out_dir, f"{filename}.mp3")
-    with open(out_file, "wb") as f:
-        f.write(response.content)
 
 def scrape_dance_info() -> pd.DataFrame:
     js_obj = re.compile(r"{(.|\n)*}")
scrapers/spotify.py ADDED
@@ -0,0 +1,87 @@
+import spotipy
+from spotipy.oauth2 import SpotifyClientCredentials
+import os
+import json
+from pathlib import Path
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from utils import download_song
+import time
+
+def set_env():
+    here = os.path.dirname(__file__)
+    with open(os.path.join(here, "auth", "spotify.json"), "r") as f:
+        config = json.load(f)
+    os.environ["SPOTIPY_CLIENT_ID"] = config["client_id"]
+    os.environ["SPOTIPY_CLIENT_SECRET"] = config["client_secret"]
+    os.environ["SPOTIPY_REDIRECT_URI"] = "https://localhost:8080/callback"
+
+set_env()
+
+
+def get_song_preview_url(song_name: str, spotify: spotipy.Spotify, artist: str = None) -> str | None:
+    info = {
+        "track": song_name
+    }
+    if artist is not None:
+        info["artist"] = artist
+    query = " ".join(f"{k}: {v}" for k, v in info.items())
+    results = spotify.search(query, type="track", limit=1)["tracks"]["items"]
+    valid_results = len(results) > 0 and results[0] is not None and "preview_url" in results[0]
+    if not valid_results:
+        return None
+    song = results[0]
+    return song["preview_url"]
+
+def patch_missing_songs(
+    df: pd.DataFrame,
+) -> pd.DataFrame:
+    spotify = spotipy.Spotify(auth_manager=SpotifyClientCredentials())
+    # find songs with missing previews
+    audio_urls = df["Sample"].replace(".", np.nan)
+    missing_audio = pd.isna(audio_urls)
+    missing_df = df[missing_audio]
+    def patch_preview(row: pd.Series):
+        song: str = row["Title"]
+        artist: str = row["Artist"]
+        preview_url = get_song_preview_url(song, spotify, artist)
+        if preview_url is not None:
+            row["Sample"] = preview_url
+        return row
+    backup_file = open("data/backup_1.csv", "a")
+    rows = []
+    indices = []
+    total_rows = len(missing_df)
+    for i, row in tqdm(missing_df.iloc[11121:].iterrows(), total=total_rows):
+        patched_row = patch_preview(row)
+        backup_file.write(f"{i}, {patched_row['Sample']}\n")
+        rows.append(patch_preview(row))
+        indices.append(i)
+
+    patched_df = pd.DataFrame(rows, index=indices)
+    df.update(patched_df)
+    return df
+
+
+def download_links():
+    start = 3180
+    with open("data/backup_2.csv") as f:
+        links = [x.split(",")[1].strip() for x in f.readlines()]
+    links = links[start:]
+    links = [l for l in links if "https" in l]
+    links = links[2680:]
+    for link in tqdm(links, "Songs Downloaded"):
+        download_song(link, "data/spotify-samples")
+        time.sleep(5e-3)  # hopefully wont be rate limited with delay 🤞
+
+
+if __name__ == "__main__":
+    df = pd.read_csv("data/songs.csv")
+    patched = patch_missing_songs(df)
+    patched.to_csv("data/last_part.csv")
+
scrapers/utils.py ADDED
@@ -0,0 +1,9 @@
+import requests
+from pathlib import Path
+
+def download_song(url: str, out_dir: str, file_type="mp3"):
+    response = requests.get(url)
+    filename = url.split("/")[-1]
+    out_file = Path(out_dir, f"{filename}.{file_type}")
+    with open(out_file, "wb") as f:
+        f.write(response.content)