waidhoferj committed
Commit 7b37b0e
1 Parent(s): 5649272

added model weights

.gitattributes CHANGED
@@ -1 +1,2 @@
 *.wav filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,5 +1,8 @@
 __pycache__
 .DS_Store
-data
+data/samples
+data/samples-backup.zip
+data/samples-backup.zip
+data/songs.csv
 logs
 gradio_cached_examples
app.py CHANGED
@@ -4,15 +4,17 @@ import numpy as np
 import torch
 from preprocessing.preprocess import AudioPipeline
 from preprocessing.preprocess import AudioPipeline
-from dancer_net.dancer_net import ShortChunkCNN
+from models.residual import ResidualDancer
 import os
 import json
 from functools import cache
 import pandas as pd
 
+
+
 @cache
-def get_model(device) -> tuple[ShortChunkCNN, np.ndarray]:
-    model_path = "logs/20221226-230930"
+def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
+    model_path = "models/weights/ResidualDancer"
     weights = os.path.join(model_path, "dancer_net.pt")
     config_path = os.path.join(model_path, "config.json")
 
@@ -20,7 +22,7 @@ def get_model(device) -> tuple[ShortChunkCNN, np.ndarray]:
         config = json.load(f)
     labels = np.array(sorted(config["classes"]))
 
-    model = ShortChunkCNN(n_class=len(labels))
+    model = ResidualDancer(n_classes=len(labels))
     model.load_state_dict(torch.load(weights))
    model = model.to(device).eval()
    return model, labels
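
For context, a minimal sketch of how the relocated `get_model` might be called from the app. The device selection and the print are illustrative assumptions, not code from this commit:

    import torch

    # Hypothetical caller; @cache makes repeated lookups cheap.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, labels = get_model(device)
    print(labels)  # sorted dance-id codes, e.g. ['ATN', 'BBA', ...]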
data/dance_mapping.csv ADDED
@@ -0,0 +1,48 @@
+id,name
+SWZ,Slow Waltz
+CSW,Cross-step Waltz
+CFT,Castle Foxtrot
+SFT,Slow Foxtrot
+TGO,Tango (Ballroom)
+PBD,Peabody
+VWZ,Viennese Waltz
+QST,Quickstep
+BOL,Bolero
+CHA,Cha Cha
+MBO,Mambo
+JIV,Jive
+RMB,Rumba
+ECS,East Coast Swing
+WCS,West Coast Swing
+HST,Hustle
+MRG,Merengue
+PDL,Paso Doble
+SMB,Samba
+PLK,Polka
+SLS,Salsa
+BCH,Bachata
+NC2,Night Club Two Step
+C2S,Country Two Step
+CMB,Cumbia
+LHP,Lindy Hop
+CST,Charleston
+CSG,Carolina Shag
+CLS,Collegiate Shag
+ATN,Argentine Tango
+TGV,Tango Vals
+NTN,Neo Tango
+MGA,Milonga
+BSN,Bossa Nova
+JSW,Jump Swing
+BLU,Blues
+MWT,Motown
+BBA,Balboa
+JAZ,Jazz
+CNT,Contemporary
+BLT,Ballet
+BDW,Broadway
+TAP,Tap
+HHP,Hip-Hop
+BWD,Bollywood
+DSC,Disco
+FST,Freestyle
main.py DELETED
@@ -1,46 +0,0 @@
-import torchaudio
-from preprocessing.preprocess import AudioPipeline
-from dancer_net.dancer_net import ShortChunkCNN
-import torch
-import numpy as np
-import os
-import json
-
-if __name__ == "__main__":
-
-    audio_file = "data/samples/mzm.iqskzxzx.aac.p.m4a.wav"
-    seconds = 6
-    model_path = "logs/20221226-230930"
-    weights = os.path.join(model_path, "dancer_net.pt")
-    config_path = os.path.join(model_path, "config.json")
-    device = "mps"
-    threshold = 0.5
-
-    with open(config_path) as f:
-        config = json.load(f)
-    labels = np.array(sorted(config["classes"]))
-
-    audio_pipeline = AudioPipeline()
-    waveform, sample_rate = torchaudio.load(audio_file)
-    waveform = waveform[:, :seconds * sample_rate]
-    spectrogram = audio_pipeline(waveform)
-    spectrogram = spectrogram.unsqueeze(0).to(device)
-
-    model = ShortChunkCNN(n_class=len(labels))
-    model.load_state_dict(torch.load(weights))
-    model = model.to(device).eval()
-
-    with torch.no_grad():
-        results = model(spectrogram)
-    results = results.squeeze(0).detach().cpu().numpy()
-    results = results > threshold
-    results = labels[results]
-    print(results)
-
-
-
-
-
-
-
-
dancer_net/dancer_net.py → models/residual.py RENAMED
@@ -1,16 +1,12 @@
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchaudio import transforms as taT, functional as taF
 
+# Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
 
-
-DEVICE = "mps"
-class ShortChunkCNN(nn.Module):
+class ResidualDancer(nn.Module):
     def __init__(self,
                  n_channels=128,
-                 sample_rate=16000,
-                 n_class=50):
+                 n_classes=50):
         super().__init__()
 
         # Spectrogram
@@ -18,19 +14,19 @@ class ShortChunkCNN(nn.Module):
 
         # CNN
         self.res_layers = nn.Sequential(
-            Res_2d(1, n_channels, stride=2),
-            Res_2d(n_channels, n_channels, stride=2),
-            Res_2d(n_channels, n_channels*2, stride=2),
-            Res_2d(n_channels*2, n_channels*2, stride=2),
-            Res_2d(n_channels*2, n_channels*2, stride=2),
-            Res_2d(n_channels*2, n_channels*2, stride=2),
-            Res_2d(n_channels*2, n_channels*4, stride=2)
+            ResBlock(1, n_channels, stride=2),
+            ResBlock(n_channels, n_channels, stride=2),
+            ResBlock(n_channels, n_channels*2, stride=2),
+            ResBlock(n_channels*2, n_channels*2, stride=2),
+            ResBlock(n_channels*2, n_channels*2, stride=2),
+            ResBlock(n_channels*2, n_channels*2, stride=2),
+            ResBlock(n_channels*2, n_channels*4, stride=2)
         )
 
         # Dense
         self.dense1 = nn.Linear(n_channels*4, n_channels*4)
         self.bn = nn.BatchNorm1d(n_channels*4)
-        self.dense2 = nn.Linear(n_channels*4, n_class)
+        self.dense2 = nn.Linear(n_channels*4, n_classes)
         self.dropout = nn.Dropout(0.3)
 
     def forward(self, x):
@@ -56,7 +52,7 @@ class ShortChunkCNN(nn.Module):
         return x
 
 
-class Res_2d(nn.Module):
+class ResBlock(nn.Module):
    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super().__init__()
        # convolution
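
As a rough sanity check on the rename, the class should still instantiate and match the shipped checkpoint. A sketch, assuming float32 weights (the ~48 MB `dancer_net.pt` added below implies roughly 12M parameters):

    import torch
    from models.residual import ResidualDancer

    # 20 classes, matching models/weights/ResidualDancer/config.json
    model = ResidualDancer(n_classes=20).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{n_params:,} parameters, ~{n_params * 4 / 1e6:.0f} MB as float32")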
models/weights/ResidualDancer/config.json ADDED
@@ -0,0 +1,24 @@
+{
+    "classes": [
+        "ATN",
+        "BBA",
+        "BCH",
+        "BLU",
+        "CHA",
+        "CMB",
+        "CSG",
+        "ECS",
+        "HST",
+        "JIV",
+        "LHP",
+        "QST",
+        "RMB",
+        "SFT",
+        "SLS",
+        "SMB",
+        "SWZ",
+        "TGO",
+        "VWZ",
+        "WCS"
+    ]
+}
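
A short illustrative snippet (not part of this commit) showing how these id codes can be resolved to display names with the new `data/dance_mapping.csv`:

    import json
    import pandas as pd

    with open("models/weights/ResidualDancer/config.json") as f:
        classes = json.load(f)["classes"]

    # Map ids like "ATN" to names like "Argentine Tango".
    mapping = pd.read_csv("data/dance_mapping.csv").set_index("id")["name"]
    print([mapping[c] for c in classes])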
models/weights/ResidualDancer/dancer_net.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1888558eed82a5d99ac1dab55969a9ea36455d11a9370355d1f2b984598d30ff
+size 48453416
train.py CHANGED
@@ -13,10 +13,30 @@ from sklearn.model_selection import KFold
 from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
 from preprocessing.dataset import SongDataset
 from preprocessing.preprocess import get_examples
-from dancer_net.dancer_net import ShortChunkCNN
+from models.residual import ResidualDancer
 
 DEVICE = "mps"
 SEED = 42
+TARGET_CLASSES = ['ATN',
+    'BBA',
+    'BCH',
+    'BLU',
+    'CHA',
+    'CMB',
+    'CSG',
+    'ECS',
+    'HST',
+    'JIV',
+    'LHP',
+    'QST',
+    'RMB',
+    'SFT',
+    'SLS',
+    'SMB',
+    'SWZ',
+    'TGO',
+    'VWZ',
+    'WCS']
 
 def get_timestamp() -> str:
     return datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
@@ -115,28 +135,8 @@ def train(
 
 
 def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
-    target_classes = ['ATN',
-        'BBA',
-        'BCH',
-        'BLU',
-        'CHA',
-        'CMB',
-        'CSG',
-        'ECS',
-        'HST',
-        'JIV',
-        'LHP',
-        'QST',
-        'RMB',
-        'SFT',
-        'SLS',
-        'SMB',
-        'SWZ',
-        'TGO',
-        'VWZ',
-        'WCS']
     df = pd.read_csv("data/songs.csv")
-    x,y = get_examples(df, "data/samples",class_list=target_classes)
+    x,y = get_examples(df, "data/samples",class_list=TARGET_CLASSES)
 
     dataset = SongDataset(x,y)
     splits=KFold(n_splits=k,shuffle=True,random_state=seed)
@@ -149,7 +149,7 @@ def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
     train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
     test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
     n_classes = len(y[0])
-    model = ShortChunkCNN(n_class=n_classes).to(device)
+    model = ResidualDancer(n_classes=n_classes).to(device)
     model, _ = train(model,train_loader, epochs=2, device=device)
     val_metrics = evaluate(model, test_loader, nn.BCELoss())
     metrics.append(val_metrics)
@@ -164,28 +164,9 @@ def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
 
 
 def train_model():
-    target_classes = ['ATN',
-        'BBA',
-        'BCH',
-        'BLU',
-        'CHA',
-        'CMB',
-        'CSG',
-        'ECS',
-        'HST',
-        'JIV',
-        'LHP',
-        'QST',
-        'RMB',
-        'SFT',
-        'SLS',
-        'SMB',
-        'SWZ',
-        'TGO',
-        'VWZ',
-        'WCS']
+
     df = pd.read_csv("data/songs.csv")
-    x,y = get_examples(df, "data/samples",class_list=target_classes)
+    x,y = get_examples(df, "data/samples",class_list=TARGET_CLASSES)
     dataset = SongDataset(x,y)
     train_count = int(len(dataset) * 0.9)
     datasets = random_split(dataset, [train_count, len(dataset) - train_count], torch.Generator().manual_seed(SEED))
@@ -193,7 +174,7 @@ def train_model():
     train_data, val_data = data_loaders
     example_spec, example_label = dataset[0]
     n_classes = len(example_label)
-    model = ShortChunkCNN(n_class=n_classes).to(DEVICE)
+    model = ResidualDancer(n_classes=n_classes).to(DEVICE)
     model, metrics = train(model,train_data, val_data, epochs=3, device=DEVICE)
 
     log_dir = os.path.join(
@@ -201,11 +182,11 @@ def train_model():
     )
     os.makedirs(log_dir, exist_ok=True)
 
-    torch.save(model.state_dict(), os.path.join(log_dir, "dancer_net.pt"))
+    torch.save(model.state_dict(), os.path.join(log_dir, "residual_dancer.pt"))
     metrics = pd.DataFrame(metrics)
     metrics.to_csv(os.path.join(log_dir, "metrics.csv"))
     config = {
-        "classes": target_classes
+        "classes": TARGET_CLASSES
    }
    with open(os.path.join(log_dir, "config.json")) as f:
        json.dump(config, f)
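
One caveat left unchanged by this diff: the final block opens `config.json` without a write mode, so `json.dump` would raise `io.UnsupportedOperation: not writable`. A minimal sketch of the presumably intended write, using a hypothetical `save_config` helper:

    import json
    import os

    def save_config(log_dir: str, classes: list[str]) -> None:
        # "w" is required here; open() defaults to read-only "r".
        with open(os.path.join(log_dir, "config.json"), "w") as f:
            json.dump({"classes": classes}, f)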