waidhoferj committed · Commit 0030bc6 · 1 Parent(s): 4b8361a
updates

Files changed:
- .gitattributes +1 -0
- .gitignore +3 -8
- README.md +8 -0
- TODO.md +10 -0
- assets/song-samples/take_it_to_the_limit.wav +0 -3
- assets/song-samples/{alejandro.wav → the_long_day_is_over.wav} +2 -2
- audio_utils.py +42 -0
- environment.yml +3 -0
- models/config/dance-predictor.yaml +2 -8
- models/config/train.yaml +46 -23
- models/residual.py +36 -19
- models/utils.py +33 -9
- models/weights/ResidualDancer/config.json +0 -24
- models/weights/ResidualDancer/dancer_net.pt +0 -3
- assets/song-samples/exs_and_ohs.wav → models/weights/ResidualDancer/weights.ckpt +2 -2
- preprocessing/dataset.py +58 -22
- preprocessing/pipelines.py +87 -41
- preprocessing/preprocess.py +52 -5
- requirements.txt +2 -1
- scrapers/spotify.py +7 -20
- tests.py +22 -0
- train.py +50 -177
.gitattributes
CHANGED
@@ -1,2 +1,3 @@
 *.wav filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -1,15 +1,10 @@
 __pycache__
 .DS_Store
-data
-data/spotify-samples
-data/samples-backup.zip
-data/songs.csv
-data/songs_original.csv
+data
 logs
 gradio_cached_examples
 explore.ipynb
 scrapers/auth
 lightning_logs
-
-
-data/*.zip
+.lr_find_*
+.cache
README.md
CHANGED
@@ -11,3 +11,11 @@ pinned: false
 ---
 
 # Dance Classifier
+
+Classifies the dance style that best accompanies a provided song. Users record or upload an audio clip and the model provides a list of matching dance styles.
+
+## Getting Started
+
+1. Download dependencies: `conda env create --file environment.yml`
+2. Open environment: `conda activate dancer-net`
+3. Start the demo application: `python app.py`
TODO.md
ADDED
@@ -0,0 +1,10 @@
+- ✅ Ensure app.py audio input sounds like training data
+- Verify that the training spectrogram matches the predict spectrogram
+- Count number of example misses in dataset loading
+- Verify windowing and jitter params in Song Dataset
+- Create an attention-based network
+- ✅ Increase parameter count in network
+- Verify that labels really match what is on the music4dance site
+- Read the Medium series about audio DL
+- double check \_rectify_duration
+- Filter out songs that have only one vote
assets/song-samples/take_it_to_the_limit.wav
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c69e0eeb4321c44daaaaf95dd596b1d813b9f7e9b5ef4ac5ae9fe11878d4b13b
-size 5292082
assets/song-samples/{alejandro.wav → the_long_day_is_over.wav}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c8f957921bbd5c322f67748aca228dd7ebf9af005692c57d1050299861883214
+size 5290062
audio_utils.py
ADDED
@@ -0,0 +1,42 @@
+import librosa
+from IPython.display import Audio, display
+import matplotlib.pyplot as plt
+import torch
+SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav"
+
+SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
+
+def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
+    spec = spec.squeeze(0)
+    spec = spec.numpy()
+    fig, axs = plt.subplots(1, 1)
+    axs.set_title(title or "Spectrogram (db)")
+    axs.set_ylabel(ylabel)
+    axs.set_xlabel("frame")
+    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
+    if xmax:
+        axs.set_xlim((0, xmax))
+    fig.colorbar(im, ax=axs)
+    plt.show(block=False)
+
+def play_audio(waveform, sample_rate):
+    waveform = waveform.numpy()
+
+    num_channels, num_frames = waveform.shape
+    if num_channels == 1:
+        display(Audio(waveform[0], rate=sample_rate))
+    elif num_channels == 2:
+        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
+    else:
+        raise ValueError("Waveform with more than 2 channels are not supported.")
+
+def get_rir_sample(path, resample=None, processed=False):
+    rir_raw, sample_rate = torch.load(path)
+    if not processed:
+        return rir_raw, sample_rate
+    rir = rir_raw[:, int(sample_rate*1.01):int(sample_rate*1.3)]
+    rir = rir / torch.norm(rir, p=2)
+    rir = torch.flip(rir, [1])
+    return rir, sample_rate
+
+
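The helpers above are notebook utilities (IPython display plus matplotlib). A rough usage sketch, not part of the commit; the wav path is just an example, and plot_spectrogram expects a power spectrogram, which it converts to dB itself:

import torchaudio
from audio_utils import play_audio, plot_spectrogram

waveform, sample_rate = torchaudio.load("assets/song-samples/the_long_day_is_over.wav")  # any local wav
play_audio(waveform, sample_rate)  # renders an IPython Audio widget for 1- or 2-channel audio

mel = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=64, n_fft=1024)(waveform)
plot_spectrogram(mel, title="Mel spectrogram")  # squeezes the channel dim and plots librosa.power_to_db(mel)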
environment.yml
CHANGED
@@ -15,6 +15,9 @@ dependencies:
 - requests
 - bidict
 - tqdm
+- pytorch-lightning
+- rich
 - pip
 - gradio
+- wakepy
 prefix: /opt/homebrew/Caskroom/miniforge/base/envs/dancer-net
models/config/dance-predictor.yaml
CHANGED
@@ -1,20 +1,15 @@
-weight_path:
+weight_path: models/weights/ResidualDancer/weights.ckpt
 expected_duration: 6
-threshold: 0.
+threshold: 0.4
 resample_frequency: 16000
 device: cpu
 labels:
 - Argentine Tango
-- Balboa
 - Bachata
-- Blues
 - Cha Cha
-- Cumbia
-- Carolina Shag
 - East Coast Swing
 - Hustle
 - Jive
-- Lindy Hop
 - Quickstep
 - Rumba
 - Slow Foxtrot
@@ -23,4 +18,3 @@ labels:
 - Slow Waltz
 - Tango (Ballroom)
 - Viennese Waltz
-- West Coast Swing
models/config/train.yaml
CHANGED
@@ -1,23 +1,46 @@
+global:
+  device: mps
+  seed: 42
+  dance_ids:
+    - ATN
+    - BCH
+    - CHA
+    - ECS
+    - HST
+    - JIV
+    - QST
+    - RMB
+    - SFT
+    - SLS
+    - SMB
+    - SWZ
+    - TGO
+    - VWZ
+    - WCS
+data_module:
+  batch_size: 1024
+  num_workers: 10
+  min_votes: 2
+  song_data_path: data/songs_cleaned.csv
+  song_audio_path: data/samples
+  dataset_kwargs:
+    audio_window_duration: 6
+    audio_window_jitter: 1.5
+    audio_pipeline_kwargs:
+      mask_count: 0 # Don't mask the data
+      snr_mean: 15.0 # Pretty much eliminate the noise
+      freq_mask_size: 10
+      time_mask_size: 80
+
+trainer:
+  log_every_n_steps: 15
+  accelerator: gpu
+  max_epochs: 50
+  min_epochs: 5
+  fast_dev_run: False
+  track_grad_norm: 2
+  # gradient_clip_val: 0.5
+training_environment:
+  learning_rate: 0.0033
+model:
+  n_channels: 128
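The new config sections map directly onto the constructors used in train.py further down in this commit (data_module onto DanceDataModule, trainer onto pl.Trainer, model onto ResidualDancer). A minimal sketch of reading it, mirroring get_config in train.py:

import yaml

with open("models/config/train.yaml") as f:
    config = yaml.safe_load(f)

print(config["global"]["dance_ids"])            # 15 dance id codes
print(config["data_module"]["dataset_kwargs"])  # forwarded to SongDataset
print(config["trainer"]["max_epochs"])          # unpacked into pl.Trainer(**config["trainer"])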
models/residual.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import torchaudio
 import yaml
 from .utils import calculate_metrics
-from preprocessing.pipelines import
+from preprocessing.pipelines import WaveformPreprocessing, AudioToSpectrogram
 
 # Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
 
@@ -15,6 +15,9 @@ class ResidualDancer(nn.Module):
     def __init__(self,n_channels=128, n_classes=50):
         super().__init__()
 
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+
         # Spectrogram
         self.spec_bn = nn.BatchNorm2d(1)
 
@@ -33,7 +36,7 @@ class ResidualDancer(nn.Module):
         self.dense1 = nn.Linear(n_channels*4, n_channels*4)
         self.bn = nn.BatchNorm1d(n_channels*4)
         self.dense2 = nn.Linear(n_channels*4, n_classes)
-        self.dropout = nn.Dropout(0.
+        self.dropout = nn.Dropout(0.2)
 
     def forward(self, x):
         x = self.spec_bn(x)
@@ -88,34 +91,51 @@ class ResBlock(nn.Module):
 
 class TrainingEnvironment(pl.LightningModule):
 
-    def __init__(self, model: nn.Module, criterion: nn.Module, learning_rate=1e-4, *args, **kwargs):
+    def __init__(self, model: nn.Module, criterion: nn.Module, config:dict, learning_rate=1e-4, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model = model
         self.criterion = criterion
         self.learning_rate = learning_rate
+        self.config=config
+        self.save_hyperparameters({
+            "model": type(model).__name__,
+            "loss": type(criterion).__name__,
+            "config": config,
+            **kwargs
+        })
 
     def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
         features, labels = batch
         outputs = self.model(features)
         loss = self.criterion(outputs, labels)
-
-        self.log_dict(
+        metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
+        self.log_dict(metrics, prog_bar=True)
+        # Log spectrograms
+        if batch_index % 100 == 0:
+            tensorboard = self.logger.experiment
+            img_index = torch.randint(0, len(features), (1,)).item()
+            img = features[img_index][0]
+            img = (img - img.min()) / (img.max() - img.min())
+            tensorboard.add_image(f"batch: {batch_index}, element: {img_index}", img, 0, dataformats='HW')
         return loss
 
+
     def validation_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
         x, y = batch
         preds = self.model(x)
-        metrics = calculate_metrics(preds, y, prefix="
-        metrics["
-        self.log_dict(metrics)
+        metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
+        metrics["val/loss"] = self.criterion(preds, y)
+        self.log_dict(metrics,prog_bar=True)
 
     def test_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
         x, y = batch
         preds = self.model(x)
-        self.log_dict(calculate_metrics(preds, y, prefix="
+        self.log_dict(calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True)
 
     def configure_optimizers(self):
-
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
+        return [optimizer]
 
 
 class DancePredictor:
@@ -133,7 +153,8 @@ class DancePredictor:
         self.expected_duration = expected_duration
         self.threshold = threshold
         self.resample_frequency = resample_frequency
-        self.
+        self.preprocess_waveform = WaveformPreprocessing(resample_frequency * expected_duration)
+        self.audio_to_spectrogram = AudioToSpectrogram(resample_frequency)
         self.labels = np.array(labels)
         self.device = device
         self.model = self.get_model(weight_path)
@@ -155,20 +176,16 @@ class DancePredictor:
 
     @torch.no_grad()
     def __call__(self, waveform: np.ndarray, sample_rate:int) -> dict[str,float]:
-
-        if min_sample_len > len(waveform):
-            raise Exception("You must record for at least 6 seconds")
-        if len(waveform.shape) > 1 and waveform.shape[1] > 1:
+        if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
             waveform = waveform.transpose(1,0)
-
-        else:
+        elif len(waveform.shape) == 1:
             waveform = np.expand_dims(waveform, 0)
-        waveform = waveform[: ,:min_sample_len]
         waveform = torch.from_numpy(waveform.astype("int16"))
         waveform = torchaudio.functional.apply_codec(waveform,sample_rate, "wav", channels_first=True)
 
         waveform = torchaudio.functional.resample(waveform, sample_rate,self.resample_frequency)
-
+        waveform = self.preprocess_waveform(waveform)
+        spectrogram = self.audio_to_spectrogram(waveform)
         spectrogram = spectrogram.unsqueeze(0).to(self.device)
 
         results = self.model(spectrogram)
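The reworked DancePredictor.__call__ no longer enforces a minimum recording length; it only normalizes channel order before the codec, resample, and spectrogram steps. A standalone sketch of that orientation logic (the "(samples, channels)" assumption about the demo's recorder output is mine, not stated in the diff):

import numpy as np

def to_channels_first(waveform: np.ndarray) -> np.ndarray:
    # Mirrors the new branch in DancePredictor.__call__
    if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
        waveform = waveform.transpose(1, 0)      # (samples, channels) -> (channels, samples)
    elif len(waveform.shape) == 1:
        waveform = np.expand_dims(waveform, 0)   # mono vector -> (1, samples)
    return waveform

print(to_channels_first(np.zeros((44100, 2))).shape)  # (2, 44100)
print(to_channels_first(np.zeros(44100)).shape)       # (1, 44100)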
models/utils.py
CHANGED
@@ -4,6 +4,10 @@ import numpy as np
 from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
 
 class LabelWeightedBCELoss(nn.Module):
+    """
+    Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
+    Allows for the weighing of each probability distribution wrt loss.
+    """
     def __init__(self, label_weights:torch.Tensor, reduction="mean"):
         super().__init__()
         self.label_weights = label_weights
@@ -22,17 +26,37 @@ class LabelWeightedBCELoss(nn.Module):
         return self.reduction(losses)
 
 
-
+# TODO: Code a onehot
+
+
+def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
     target = target.detach().cpu().numpy()
     pred = pred.detach().cpu().numpy()
-
+    params = {
+        "y_true": target if multi_label else target.argmax(1) ,
+        "y_pred": np.array(pred > threshold, dtype=float) if multi_label else pred.argmax(1),
+        "zero_division": 0,
+        "average":"macro"
+    }
     metrics= {
-        'precision': precision_score(
-        'recall': recall_score(
-        'f1': f1_score(
-        'accuracy': accuracy_score(y_true=
+        'precision': precision_score(**params),
+        'recall': recall_score(**params),
+        'f1': f1_score(**params),
+        'accuracy': accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
     }
-
-
+    return {prefix + k: torch.tensor(v,dtype=torch.float32) for k,v in metrics.items()}
 
-
+class EarlyStopping:
+    def __init__(self, patience=0):
+        self.patience = patience
+        self.last_measure = np.inf
+        self.consecutive_increase = 0
+
+    def step(self, val) -> bool:
+        if self.last_measure <= val:
+            self.consecutive_increase +=1
+        else:
+            self.consecutive_increase = 0
+        self.last_measure = val
+
+        return self.patience < self.consecutive_increase
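A quick sketch of the reworked calculate_metrics on dummy multi-label outputs (values in the comment are approximate):

import torch
from models.utils import calculate_metrics

preds = torch.tensor([[0.9, 0.2, 0.7],
                      [0.1, 0.8, 0.4]])
target = torch.tensor([[1.0, 0.0, 1.0],
                       [0.0, 1.0, 1.0]])

metrics = calculate_metrics(preds, target, threshold=0.5, prefix="val/", multi_label=True)
print(metrics)
# {'val/precision': tensor(1.), 'val/recall': tensor(0.8333), 'val/f1': tensor(0.8889), 'val/accuracy': tensor(0.5000)}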
models/weights/ResidualDancer/config.json
DELETED
@@ -1,24 +0,0 @@
-{
-    "classes": [
-        "ATN",
-        "BBA",
-        "BCH",
-        "BLU",
-        "CHA",
-        "CMB",
-        "CSG",
-        "ECS",
-        "HST",
-        "JIV",
-        "LHP",
-        "QST",
-        "RMB",
-        "SFT",
-        "SLS",
-        "SMB",
-        "SWZ",
-        "TGO",
-        "VWZ",
-        "WCS"
-    ]
-}
models/weights/ResidualDancer/dancer_net.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1888558eed82a5d99ac1dab55969a9ea36455d11a9370355d1f2b984598d30ff
-size 48453416
assets/song-samples/exs_and_ohs.wav → models/weights/ResidualDancer/weights.ckpt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e107090ff62ac0b79f4f40271e8b1dd6c3d10d8146264ec49df3c8febe99aa23
+size 193651217
preprocessing/dataset.py
CHANGED
@@ -3,87 +3,122 @@ from torch.utils.data import Dataset, DataLoader, random_split
 import numpy as np
 import pandas as pd
 import torchaudio as ta
-from .pipelines import
+from .pipelines import AudioTrainingPipeline
 import pytorch_lightning as pl
 from .preprocess import get_examples
+from sklearn.model_selection import train_test_split
 
 
 
 class SongDataset(Dataset):
     def __init__(self,
                  audio_paths: list[str],
-                 dance_labels: list[np.ndarray],
+                 dance_labels: list[np.ndarray],
                  audio_duration=30, # seconds
                  audio_window_duration=6, # seconds
+                 audio_window_jitter=0.0, # seconds
+                 audio_pipeline_kwargs={},
+                 resample_frequency=16000
                  ):
         assert audio_duration % audio_window_duration == 0, "Audio window should divide duration evenly."
+        assert audio_window_duration > audio_window_jitter, "Jitter should be a small fraction of the audio window duration."
 
         self.audio_paths = audio_paths
         self.dance_labels = dance_labels
         audio_info = ta.info(audio_paths[0])
         self.sample_rate = audio_info.sample_rate
         self.audio_window_duration = int(audio_window_duration)
+        self.audio_window_jitter = audio_window_jitter
        self.audio_duration = int(audio_duration)
 
-        self.audio_pipeline =
+        self.audio_pipeline = AudioTrainingPipeline(self.sample_rate, resample_frequency, audio_window_duration, **audio_pipeline_kwargs)
 
     def __len__(self):
         return len(self.audio_paths) * self.audio_duration // self.audio_window_duration
 
-    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
+    def __getitem__(self, idx:int) -> tuple[torch.Tensor, torch.Tensor]:
         waveform = self._waveform_from_index(idx)
+        assert waveform.shape[1] > 10, f"No data found: {self._backtrace_audio_path(idx)}"
         spectrogram = self.audio_pipeline(waveform)
 
         dance_labels = self._label_from_index(idx)
 
-
+        example_is_valid = self._validate_output(spectrogram, dance_labels)
+        if example_is_valid:
+            return spectrogram, dance_labels
+        else:
+            # Try the previous one
+            # This happens when some of the audio recordings are really quiet
+            # This WILL NOT leak into other data partitions because songs belong entirely to a partition
+            return self[idx-1]
 
+    def _convert_idx(self,idx:int) -> int:
+        return idx * self.audio_window_duration // self.audio_duration
+
+    def _backtrace_audio_path(self, index:int) -> str:
+        return self.audio_paths[self._convert_idx(index)]
+
+    def _validate_output(self,x,y):
+        is_finite = not torch.any(torch.isinf(x))
+        is_numerical = not torch.any(torch.isnan(x))
+        has_data = torch.any(x != 0.0)
+        is_binary = len(torch.unique(y)) < 3
+        return all((is_finite,is_numerical, has_data, is_binary))
 
     def _waveform_from_index(self, idx:int) -> torch.Tensor:
-
-
+        audio_filepath = self.audio_paths[self._convert_idx(idx)]
+        num_windows = self.audio_duration // self.audio_window_duration
+        frame_index = idx % num_windows
+        jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
+        jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
+        jitter = int(torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate)
+        frame_offset = frame_index * self.audio_window_duration * self.sample_rate + jitter
         num_frames = self.sample_rate * self.audio_window_duration
-        waveform, sample_rate = ta.load(
+        waveform, sample_rate = ta.load(audio_filepath, frame_offset=frame_offset, num_frames=num_frames)
         assert sample_rate == self.sample_rate, f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
         return waveform
 
 
     def _label_from_index(self, idx:int) -> torch.Tensor:
-
-        return torch.from_numpy(self.dance_labels[label_idx])
-
+        return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
 
 class DanceDataModule(pl.LightningDataModule):
     def __init__(self,
-                 song_data_path="data/
+                 song_data_path="data/songs_cleaned.csv",
                  song_audio_path="data/samples",
                  test_proportion=0.15,
                  val_proportion=0.1,
                  target_classes:list[str]=None,
+                 min_votes=1,
                  batch_size:int=64,
-                 num_workers=10
+                 num_workers=10,
+                 dataset_kwargs={}
                  ):
         super().__init__()
         self.song_data_path = song_data_path
         self.song_audio_path = song_audio_path
         self.val_proportion=val_proportion
         self.test_proportion=test_proportion
-        self.
+        self.train_proportion= 1.-test_proportion-val_proportion
         self.target_classes=target_classes
         self.batch_size = batch_size
         self.num_workers = num_workers
+        self.dataset_kwargs = dataset_kwargs
 
-        df = pd.read_csv(
-        self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes)
-
+        df = pd.read_csv(song_data_path)
+        self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes, multi_label=True, min_votes=min_votes)
 
     def setup(self, stage: str):
-
-        self.train_ds
+        train_i, val_i, test_i = random_split(np.arange(len(self.x)), [self.train_proportion, self.val_proportion, self.test_proportion])
+        self.train_ds = self._dataset_from_indices(train_i)
+        self.val_ds = self._dataset_from_indices(val_i)
+        self.test_ds = self._dataset_from_indices(test_i)
+
+    def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
+        return SongDataset(self.x[idx], self.y[idx], **self.dataset_kwargs)
 
-
     def train_dataloader(self):
-        return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers)
@@ -92,4 +127,5 @@ class DanceDataModule(pl.LightningDataModule):
         return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers)
 
     def get_label_weights(self):
-
+        n_examples, n_classes = self.y.shape
+        return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
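The window selection in SongDataset is pure index arithmetic: each 30 s clip yields five 6 s examples, and _convert_idx maps an example index back to its source file. A worked example of that math (jitter omitted; values are the constructor defaults):

audio_duration = 30          # seconds per clip
audio_window_duration = 6    # seconds per training example
sample_rate = 16000
num_windows = audio_duration // audio_window_duration   # 5 examples per clip

for idx in (0, 4, 5, 7):
    song_index = idx * audio_window_duration // audio_duration         # what _convert_idx returns
    frame_index = idx % num_windows
    frame_offset = frame_index * audio_window_duration * sample_rate   # before jitter is added
    print(idx, song_index, frame_index, frame_offset)
# idx 5 is the first window of the second song; interior windows then get up to ±audio_window_jitter seconds of random offset.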
preprocessing/pipelines.py
CHANGED
@@ -1,63 +1,109 @@
 import torch
+import torchaudio
 from torchaudio import transforms as taT, functional as taF
 import torch.nn as nn
 
-
-    def __init__(
-        self,
-        input_freq=16000,
-        resample_freq=16000,
-    ):
-        super().__init__()
-        self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
-        self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
-        self.to_db = taT.AmplitudeToDB()
-
-    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        if waveform.shape[0] > 1:
-            waveform = waveform.mean(0, keepdim=True)
-
-        waveform = self.resample(waveform)
-        spectrogram = self.spec(waveform)
-        spectrogram = self.to_db(spectrogram)
-
-
-
-
-
-
-
-            taT.TimeStretch(80)
-        )
-
-    def forward(self, spectrogram:torch.Tensor) -> torch.Tensor:
-        return self.pipeline(spectrogram)
-
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, waveform:torch.Tensor) -> torch.Tensor:
-
-
-    def
-        super().__init__()
-        self.waveform_aug = WaveformAugmentationPipeline()
-        self.spec_aug = SpectrogramAugmentationPipeline()
-        self.audio_preprocessing = AudioPipeline()
-
+NOISE_PATH = "data/augmentation/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
+
+class AudioTrainingPipeline(torch.nn.Module):
+    def __init__(self,
+                 input_freq=16000,
+                 resample_freq=16000,
+                 expected_duration=6,
+                 freq_mask_size=10,
+                 time_mask_size=80,
+                 mask_count = 2,
+                 snr_mean=6.0):
+        super().__init__()
+        self.input_freq = input_freq
+        self.snr_mean = snr_mean
+        self.mask_count = mask_count
+        self.noise = self.get_noise()
+        self.resample = taT.Resample(input_freq,resample_freq)
+        self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
+        self.audio_to_spectrogram = AudioToSpectrogram(
+            sample_rate=resample_freq,
+        )
+        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
+        self.time_mask = taT.TimeMasking(time_mask_size)
+
+    def get_noise(self) -> torch.Tensor:
+        noise, sr = torchaudio.load(NOISE_PATH)
+        if noise.shape[0] > 1:
+            noise = noise.mean(0, keepdim=True)
+        if sr != self.input_freq:
+            noise = taF.resample(noise,sr, self.input_freq)
+        return noise
+
+    def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
+        num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
+        noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
+        noise_power = noise.norm(p=2)
+        signal_power = waveform.norm(p=2)
+        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
+        snr = torch.exp(snr_db / 10)
+        scale = snr * noise_power / signal_power
+        noisy_waveform = (scale * waveform + noise) / 2
+        return noisy_waveform
+
+    def forward(self, waveform:torch.Tensor) -> torch.Tensor:
+        try:
+            waveform = self.resample(waveform)
+        except:
+            print("oops")
+        waveform = self.preprocess_waveform(waveform)
+        waveform = self.add_noise(waveform)
+        spec = self.audio_to_spectrogram(waveform)
+
+        # Spectrogram augmentation
+        for _ in range(self.mask_count):
+            spec = self.freq_mask(spec)
+            spec = self.time_mask(spec)
+        return spec
+
+
+class WaveformPreprocessing(torch.nn.Module):
+
+    def __init__(self, expected_sample_length:int):
+        super().__init__()
+        self.expected_sample_length = expected_sample_length
+
+    def forward(self, waveform:torch.Tensor) -> torch.Tensor:
+        # Take out extra channels
+        if waveform.shape[0] > 1:
+            waveform = waveform.mean(0, keepdim=True)
+
+        # ensure it is the correct length
+        waveform = self._rectify_duration(waveform)
+        return waveform
+
+    def _rectify_duration(self,waveform:torch.Tensor):
+        expected_samples = self.expected_sample_length
+        sample_count = waveform.shape[1]
+        if expected_samples == sample_count:
+            return waveform
+        elif expected_samples > sample_count:
+            pad_amount = expected_samples - sample_count
+            return torch.nn.functional.pad(waveform, (0, pad_amount),mode="constant", value=0.0)
+        else:
+            return waveform[:,:expected_samples]
+
+
+class AudioToSpectrogram(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate=16000,
+    ):
+        super().__init__()
+
+        self.spec = taT.MelSpectrogram(sample_rate=sample_rate, n_mels=128, n_fft=1024) # TODO: Change mels to 64
+        self.to_db = taT.AmplitudeToDB()
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        spectrogram = self.spec(waveform)
+        spectrogram = self.to_db(spectrogram)
+        return spectrogram
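AudioTrainingPipeline.add_noise loops the babble clip to the window length and mixes it at a scale drawn around snr_mean. Note the scale uses torch.exp(snr_db / 10) rather than a base-10 dB conversion, so snr_mean behaves as a relative knob; train.yaml sets it to 15.0 to, per its own comment, pretty much eliminate the noise. The math in isolation, with white noise standing in for the VOiCES file:

import torch

waveform = torch.randn(1, 16000 * 6)   # a 6 s "signal" window
noise = torch.randn(1, 16000)          # a shorter noise clip

num_repeats = waveform.shape[1] // noise.shape[1] + 1
noise = noise.repeat(1, num_repeats)[:, :waveform.shape[1]]    # loop noise out to the signal length

snr_db = torch.normal(6.0, 1.5, (1,)).clamp_min(1.0)           # snr_mean default is 6.0
scale = torch.exp(snr_db / 10) * noise.norm(p=2) / waveform.norm(p=2)
noisy = (scale * waveform + noise) / 2
print(noisy.shape)  # torch.Size([1, 96000])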
preprocessing/preprocess.py
CHANGED
@@ -4,8 +4,9 @@ import re
 import json
 from pathlib import Path
 import os
+import torchaudio
 import torch
-
+from tqdm import tqdm
 
 def url_to_filename(url:str) -> str:
     return f"{url.split('/')[-1]}.wav"
@@ -17,6 +18,35 @@ def get_songs_with_audio(df:pd.DataFrame, audio_dir:str) -> pd.DataFrame:
     df = df[valid_audio]
     return df
 
+def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
+    """
+    Tests audio urls to ensure that their file exists and the contents is valid.
+    """
+    audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
+    def is_valid(url):
+        valid_url = type(url) == str and "http" in url
+        if not valid_url:
+            return False
+        filename = url_to_filename(url)
+        if filename not in audio_files:
+            return False
+        try:
+            w, _ = torchaudio.load(os.path.join(audio_dir, filename))
+        except:
+            return False
+        contents_invalid = torch.any(torch.isnan(w)) or torch.any(torch.isinf(w)) or len(torch.unique(w)) <= 2
+        return not contents_invalid
+
+    idxs = []
+    validations = []
+    for index, url in tqdm(audio_urls.items(), total=len(audio_urls), desc="Audio URLs Validated"):
+        idxs.append(index)
+        validations.append(is_valid(url))
+
+    return pd.Series(validations, index=idxs)
+
+
+
 def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
     tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
     dance_ratings = dance_ratings.apply(lambda v : json.loads(v.replace("'", "\"")))
@@ -64,7 +94,7 @@ def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np
     probs[probs > 0.0] = 1.0
     return probs
 
-def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[
+def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None, multi_label=True, min_votes=1) -> tuple[np.ndarray, np.ndarray]:
     sampled_songs = get_songs_with_audio(df, audio_dir)
     sampled_songs.loc[:,"DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
     if class_list is not None:
@@ -74,11 +104,28 @@ def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[
         if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
         else np.nan)
     sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
-
+    vote_mask = sampled_songs["DanceRating"].apply(lambda dances: any(votes >= min_votes for votes in dances.values()))
+    sampled_songs = sampled_songs[vote_mask]
+    labels = sampled_songs["DanceRating"].apply(lambda dances : {dance: votes for dance, votes in dances.items() if votes >= min_votes})
     unique_labels = np.array(get_unique_labels(labels))
-
+    vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
+    labels = labels.apply(lambda i : vectorizer(i, unique_labels))
 
     audio_paths = [os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]]
 
-    return audio_paths,
+    return np.array(audio_paths), np.stack(labels)
+
+
+if __name__ == "__main__":
+    links = pd.read_csv("data/backup_2.csv", index_col="index")
+    df = pd.read_csv("data/songs.csv")
+    l = links["link"].str.strip()
+    l = l.apply(lambda url : url if "http" in url else np.nan)
+    l = l.dropna()
+    df["Sample"].update(l)
+    addna = lambda url : url if type(url) == str and "http" in url else np.nan
+    df["Sample"] = df["Sample"].apply(addna)
+    is_valid = validate_audio(df["Sample"],"data/samples")
+    df["valid"] = is_valid
+    df.to_csv("data/songs_validated.csv")
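A sketch of calling the extended get_examples with the new vote filter; the CSV and sample directory are the paths train.yaml points at, so this assumes that local data layout exists:

import pandas as pd
from preprocessing.preprocess import get_examples

df = pd.read_csv("data/songs_cleaned.csv")
x, y = get_examples(df, "data/samples", class_list=["ATN", "BCH", "CHA"], multi_label=True, min_votes=2)
print(x.shape, y.shape)   # N audio paths and an N x n_labels binary matrix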
requirements.txt
CHANGED
@@ -1,6 +1,7 @@
-torchvision
 torch
+torchvision
 torchaudio
+pytorch-lightning
 numpy
 pandas
 seaborn
scrapers/spotify.py
CHANGED
@@ -49,14 +49,14 @@ def patch_missing_songs(
         if preview_url is not None:
             row["Sample"] = preview_url
         return row
-    backup_file = open("data/backup_1.csv", "a")
     rows = []
     indices = []
+    after = 18418
+    missing_df = missing_df.iloc[after:]
     total_rows = len(missing_df)
-    for i, row in tqdm(missing_df.
+    for i, row in tqdm(missing_df.iterrows(),total=total_rows):
         patched_row = patch_preview(row)
-
-        rows.append(patch_preview(row))
+        rows.append(patched_row)
         indices.append(i)
 
 
@@ -65,23 +65,10 @@ def patch_missing_songs(
     return df
 
 
-def
-
-    with open("data/backup_2.csv") as f:
+def download_links_from_backup(backup_file:str, output_dir:str):
+    with open(backup_file) as f:
         links = [x.split(",")[1].strip() for x in f.readlines()]
-    links = links[start:]
     links = [l for l in links if "https" in l]
-    links = links[2680:]
     for link in tqdm(links, "Songs Downloaded"):
-        download_song(link,
+        download_song(link, output_dir)
         time.sleep(5e-3) # hopefully wont be rate limited with delay
-
-
-
-
-if __name__ == "__main__":
-    df = pd.read_csv("data/songs.csv")
-    patched = patch_missing_songs(df)
-    patched.to_csv("data/last_part.csv")
-
-
tests.py
ADDED
@@ -0,0 +1,22 @@
+import torchaudio
+import numpy as np
+from audio_utils import play_audio
+from preprocessing.dataset import SongDataset
+
+def test_audio_splitting():
+
+
+
+    audio_paths = ["data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav"]
+    labels = [np.array([1,0,1,0])]
+    whole_song, sr = torchaudio.load("data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav")
+
+    ds = SongDataset(audio_paths, labels)
+    song_parts = (ds._waveform_from_index(i) for i in range(len(ds)))
+    print("Sample Parts")
+    for part in song_parts:
+        play_audio(part,sr)
+
+
+    print("Whole Sample")
+    play_audio(whole_song,sr)
train.py
CHANGED
@@ -1,196 +1,69 @@
-import datetime
-import os
-import torch
 from torch.utils.data import DataLoader
-import torch.nn as nn
-from tqdm import tqdm
 import pandas as pd
-
-from torch.utils.data import
-import json
+from torch import nn
+from torch.utils.data import SubsetRandomSampler
 from sklearn.model_selection import KFold
-
-from
+import pytorch_lightning as pl
+from pytorch_lightning import callbacks as cb
+from models.utils import LabelWeightedBCELoss
 from preprocessing.dataset import SongDataset
 from preprocessing.preprocess import get_examples
-from models.residual import ResidualDancer
-
-
-
-TARGET_CLASSES = ['ATN',
-                  'BBA',
-                  'BCH',
-                  'BLU',
-                  'CHA',
-                  'CMB',
-                  'CSG',
-                  'ECS',
-                  'HST',
-                  'JIV',
-                  'LHP',
-                  'QST',
-                  'RMB',
-                  'SFT',
-                  'SLS',
-                  'SMB',
-                  'SWZ',
-                  'TGO',
-                  'VWZ',
-                  'WCS']
-
-def get_timestamp() -> str:
-    return datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
-
-class EarlyStopping:
-    def __init__(self, patience=0):
-        self.patience = patience
-        self.last_measure = np.inf
-        self.consecutive_increase = 0
-
-    def step(self, val) -> bool:
-        if self.last_measure <= val:
-            self.consecutive_increase +=1
-        else:
-            self.consecutive_increase = 0
-        self.last_measure = val
-
-        return self.patience < self.consecutive_increase
-
-
-
-def calculate_metrics(pred, target, threshold=0.5, prefix=""):
-    target = target.detach().cpu().numpy()
-    pred = pred.detach().cpu().numpy()
-    pred = np.array(pred > threshold, dtype=float)
-    metrics= {
-        'precision': precision_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-        'recall': recall_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-        'f1': f1_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-        'accuracy': accuracy_score(y_true=target, y_pred=pred),
-    }
-    if prefix != "":
-        metrics = {prefix + k : v for k, v in metrics.items()}
-
-    return metrics
+from models.residual import ResidualDancer, TrainingEnvironment
+import yaml
+from preprocessing.dataset import DanceDataModule
+from wakepy import keepawake
 
+def get_config(filepath:str) -> dict:
+    with open(filepath, "r") as f:
+        config = yaml.safe_load(f)
+    return config
 
-def
-    val_metrics = []
-    for features, labels in (prog_bar := tqdm(data_loader)):
-        features = features.to(device)
-        labels = labels.to(device)
-        with torch.no_grad():
-            outputs = model(features)
-            loss = criterion(outputs, labels)
-        batch_metrics = calculate_metrics(outputs, labels, prefix="val_")
-        batch_metrics["val_loss"] = loss.item()
-        prog_bar.set_description(f'Validation - Loss: {batch_metrics["val_loss"]:.2f}, Accuracy: {batch_metrics["val_accuracy"]:.2f}')
-        val_metrics.append(batch_metrics)
-    return pd.DataFrame(val_metrics).mean()
-
-
-
-def train(
-    model: nn.Module,
-    data_loader: DataLoader,
-    val_loader=None,
-    epochs=3,
-    lr=1e-3,
-    device="mps"):
-    criterion = nn.BCELoss()
-    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
-    early_stop = EarlyStopping(1)
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
-                                                    steps_per_epoch=int(len(data_loader)),
-                                                    epochs=epochs,
-                                                    anneal_strategy='linear')
-    metrics = []
-    for epoch in range(1,epochs+1):
-        train_metrics = []
-        prog_bar = tqdm(data_loader)
-        for features, labels in prog_bar:
-            features = features.to(device)
-            labels = labels.to(device)
-            optimizer.zero_grad()
-            outputs = model(features)
-            loss = criterion(outputs, labels)
-            loss.backward()
-            optimizer.step()
-            scheduler.step()
-            batch_metrics = calculate_metrics(outputs, labels)
-            batch_metrics["loss"] = loss.item()
-            train_metrics.append(batch_metrics)
-            prog_bar.set_description(f'Training - Epoch: {epoch}/{epochs}, Loss: {batch_metrics["loss"]:.2f}, Accuracy: {batch_metrics["accuracy"]:.2f}')
-        train_metrics = pd.DataFrame(train_metrics).mean()
-        if val_loader is not None:
-            val_metrics = evaluate(model, val_loader, criterion)
-            if early_stop.step(val_metrics["val_f1"]):
-                break
-            epoch_metrics = pd.concat([train_metrics, val_metrics], axis=0)
-        else:
-            epoch_metrics = train_metrics
-        metrics.append(dict(epoch_metrics))
-
-    return model, metrics
-
-
-def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
+def cross_validation(config, k=5):
     df = pd.read_csv("data/songs.csv")
-
-
+    g_config = config["global"]
+    batch_size = config["data_module"]["batch_size"]
+    x,y = get_examples(df, "data/samples",class_list=g_config["dance_ids"])
     dataset = SongDataset(x,y)
-    splits=KFold(n_splits=k,shuffle=True,random_state=seed)
-
+    splits=KFold(n_splits=k,shuffle=True,random_state=g_config["seed"])
+    trainer = pl.Trainer(accelerator=g_config["device"])
     for fold, (train_idx,val_idx) in enumerate(splits.split(x,y)):
         print(f"Fold {fold+1}")
-
+        model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
+        train_env = TrainingEnvironment(model,nn.BCELoss())
         train_sampler = SubsetRandomSampler(train_idx)
         test_sampler = SubsetRandomSampler(val_idx)
         train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
-
-
-
-
-
-
-
-        "
-        )
-
-
-
-
+        trainer.fit(train_env, train_loader)
+        trainer.test(train_env, test_loader)
 
-def train_model():
 
-    df = pd.read_csv("data/songs.csv")
-    x,y = get_examples(df, "data/samples",class_list=TARGET_CLASSES)
-    dataset = SongDataset(x,y)
-    train_count = int(len(dataset) * 0.9)
-    datasets = random_split(dataset, [train_count, len(dataset) - train_count], torch.Generator().manual_seed(SEED))
-    data_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in datasets]
-    train_data, val_data = data_loaders
-    example_spec, example_label = dataset[0]
-    n_classes = len(example_label)
-    model = ResidualDancer(n_classes=n_classes).to(DEVICE)
-    model, metrics = train(model,train_data, val_data, epochs=3, device=DEVICE)
-
-    log_dir = os.path.join(
-        "logs", get_timestamp()
-    )
-    os.makedirs(log_dir, exist_ok=True)
-
-    torch.save(model.state_dict(), os.path.join(log_dir, "residual_dancer.pt"))
-    metrics = pd.DataFrame(metrics)
-    metrics.to_csv(os.path.join(log_dir, "metrics.csv"))
-    config = {
-        "classes": TARGET_CLASSES
-    }
-    with open(os.path.join(log_dir, "config.json")) as f:
-        json.dump(config, f)
-    print("Training information saved!")
+def train_model(config:dict):
+    TARGET_CLASSES = config["global"]["dance_ids"]
+    DEVICE = config["global"]["device"]
+    SEED = config["global"]["seed"]
+    pl.seed_everything(SEED, workers=True)
+    data = DanceDataModule(target_classes=TARGET_CLASSES, **config['data_module'])
+    model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config['model'])
+    label_weights = data.get_label_weights().to(DEVICE)
+    criterion = LabelWeightedBCELoss(label_weights) #nn.CrossEntropyLoss(label_weights)
+    train_env = TrainingEnvironment(model, criterion, config)
+    callbacks = [
+        # cb.LearningRateFinder(update_attr=True),
+        cb.EarlyStopping("val/loss", patience=5),
+        cb.StochasticWeightAveraging(1e-2),
+        cb.RichProgressBar()
+    ]
+    trainer = pl.Trainer(
+        callbacks=callbacks,
+        **config["trainer"]
+    )
+    trainer.fit(train_env, datamodule=data)
+    trainer.test(train_env, datamodule=data)
 
 if __name__ == "__main__":
-
+    config = get_config("models/config/train.yaml")
+    with keepawake():
+        train_model(config)
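
Since train.yaml already carries a fast_dev_run flag that is unpacked straight into pl.Trainer, a quick smoke test can be sketched by flipping it before calling train_model (this assumes the dataset paths named in the config exist locally):

from train import get_config, train_model

config = get_config("models/config/train.yaml")
config["trainer"]["fast_dev_run"] = True   # run a single batch through fit/validation, no checkpoints
train_model(config)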
|