Christian J. Steinmetz committed on
Commit d254115
1 Parent(s): e8eaf47

adding updated support for classifier models

cfg/exp/5-5_cls.yaml ADDED
@@ -0,0 +1,59 @@
+ # @package _global_
+ defaults:
+   - override /model: demucs
+   - override /effects: all
+ seed: 12345
+ sample_rate: 48000
+ chunk_size: 262144 # 5.5s
+ logs_dir: "./logs"
+ render_files: True
+ render_root: "/scratch/EffectSet_cjs"
+ accelerator: "gpu"
+ log_audio: False
+ # Effects
+ num_kept_effects: [0,0] # [min, max]
+ num_removed_effects: [0,5] # [min, max]
+ shuffle_kept_effects: True
+ shuffle_removed_effects: True
+ num_classes: 5
+ effects_to_keep:
+ effects_to_remove:
+   - distortion
+   - compressor
+   - reverb
+   - chorus
+   - delay
+ datamodule:
+   batch_size: 64
+   num_workers: 8
+
+ callbacks:
+   model_checkpoint:
+     _target_: pytorch_lightning.callbacks.ModelCheckpoint
+     monitor: "valid_f1_avg_epoch" # name of the logged metric which determines when model is improving
+     save_top_k: 1 # save k best models (determined by above metric)
+     save_last: True # additionally always save model from last epoch
+     mode: "max" # can be "max" or "min"
+     verbose: True
+     dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+     filename: '{epoch:02d}-{valid_f1_avg_epoch:.3f}'
+   learning_rate_monitor:
+     _target_: pytorch_lightning.callbacks.LearningRateMonitor
+     logging_interval: "step"
+   #audio_logging:
+   # _target_: remfx.callbacks.AudioCallback
+   # sample_rate: ${sample_rate}
+   # log_audio: ${log_audio}
+
+
+ trainer:
+   _target_: pytorch_lightning.Trainer
+   precision: 32 # Precision used for tensors, default `32`
+   min_epochs: 0
+   max_epochs: -1
+   log_every_n_steps: 1 # Logs metrics every N batches
+   accumulate_grad_batches: 1
+   accelerator: ${accelerator}
+   devices: 1
+   gradient_clip_val: 10.0
+   max_steps: 150000
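
The checkpoint callback above monitors "valid_f1_avg_epoch", which the classifier's LightningModule is expected to log once per validation epoch; that logging lives in remfx.models.FXClassifier and is not part of this diff. A minimal sketch of what such a metric could look like, using a macro-averaged multilabel F1 over the five removable effects (an assumption, not the committed code):

import torch
import torchmetrics

# hypothetical metric matching monitor: "valid_f1_avg_epoch" and mode: "max"
val_f1 = torchmetrics.classification.MultilabelF1Score(num_labels=5, average="macro")

preds = torch.sigmoid(torch.randn(8, 5))   # per-effect probabilities
labels = torch.randint(0, 2, (8, 5))       # which effects are present in each clip
f1 = val_f1(preds, labels)
# inside validation_step one would log it as:
#   self.log("valid_f1_avg_epoch", f1, on_step=False, on_epoch=True)
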
cfg/model/cls_panns_16k.yaml ADDED
@@ -0,0 +1,15 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 2048
+     hop_length: 512
+     n_mels: 128
+     sample_rate: 44100
+     model_sample_rate: 16000
+
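
These model configs are plain Hydra instantiation targets. Below is a minimal sketch of how the 16 kHz variant above resolves into objects; the global values (sample_rate, num_classes) normally come from cfg/exp/5-5_cls.yaml via interpolation and are filled in by hand here, and remfx must be importable:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# hand-rolled stand-in for the composed config: cls_panns_16k.yaml plus the two globals it interpolates
cfg = OmegaConf.create({
    "sample_rate": 48000,
    "num_classes": 5,
    "model": {
        "_target_": "remfx.models.FXClassifier",
        "lr": 3e-4,
        "lr_weight_decay": 1e-3,
        "sample_rate": "${sample_rate}",
        "network": {
            "_target_": "remfx.classifier.Cnn14",
            "num_classes": "${num_classes}",
            "n_fft": 2048,
            "hop_length": 512,
            "n_mels": 128,
            "sample_rate": 44100,
            "model_sample_rate": 16000,
        },
    },
})
model = instantiate(cfg.model)  # FXClassifier wrapping a Cnn14 that resamples to 16 kHz
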
cfg/model/cls_panns_44k.yaml ADDED
@@ -0,0 +1,15 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.Cnn14
+     num_classes: ${num_classes}
+     n_fft: 1024
+     hop_length: 256
+     n_mels: 128
+     sample_rate: 44100
+     model_sample_rate: 44100
+     specaugment: True
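
The 44.1 kHz variant also enables specaugment. In the Cnn14 changes further down, forward() applies self.freq_mask and self.time_mask to the mel spectrogram when specaugment=True and train=True; the construction of those masks falls outside the visible hunks, so the following is only a plausible setup using torchaudio's masking transforms, not the committed code:

import torchaudio

# assumed SpecAugment masks referenced by Cnn14.forward(); the mask widths are illustrative
freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=24)
time_mask = torchaudio.transforms.TimeMasking(time_mask_param=64)
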
cfg/model/{classifier.yaml → cls_panns_pt.yaml} RENAMED
@@ -1,14 +1,11 @@
  # @package _global_
  model:
    _target_: remfx.models.FXClassifier
-   lr: 1e-4
+   lr: 3e-4
    lr_weight_decay: 1e-3
    sample_rate: ${sample_rate}
    network:
-     _target_: remfx.cnn14.Cnn14
+     _target_: remfx.classifier.PANNs
      num_classes: ${num_classes}
-     n_fft: 4096
-     hop_length: 512
-     n_mels: 128
      sample_rate: ${sample_rate}
  
cfg/model/cls_vggish.yaml ADDED
@@ -0,0 +1,11 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.VGGish
+     num_classes: ${num_classes}
+     sample_rate: ${sample_rate}
+
cfg/model/cls_wav2clip.yaml ADDED
@@ -0,0 +1,11 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.Wav2CLIP
+     num_classes: ${num_classes}
+     sample_rate: ${sample_rate}
+
cfg/model/cls_wav2vec2.yaml ADDED
@@ -0,0 +1,11 @@
+ # @package _global_
+ model:
+   _target_: remfx.models.FXClassifier
+   lr: 3e-4
+   lr_weight_decay: 1e-3
+   sample_rate: ${sample_rate}
+   network:
+     _target_: remfx.classifier.wav2vec2
+     num_classes: ${num_classes}
+     sample_rate: ${sample_rate}
+
remfx/{cnn14.py → classifier.py} RENAMED
@@ -1,8 +1,132 @@
  import torch
  import torchaudio
  import torch.nn as nn
+ import hearbaseline
+ import hearbaseline.vggish
+ import hearbaseline.wav2vec2
+
+ import wav2clip_hear
+ import panns_hear
+
+
  import torch.nn.functional as F
- from utils import init_bn, init_layer
+ from remfx.utils import init_bn, init_layer
+
+
+ class PANNs(torch.nn.Module):
+     def __init__(
+         self, num_classes: int, sample_rate: float, hidden_dim: int = 256
+     ) -> None:
+         super().__init__()
+         self.num_classes = num_classes
+         self.model = panns_hear.load_model("hear2021-panns_hear.pth")
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=32000
+         )
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(2048, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = panns_hear.get_scene_embeddings(x.view(x.shape[0], -1), self.model)
+         return self.proj(embed)
+
+
+ class Wav2CLIP(nn.Module):
+     def __init__(
+         self,
+         num_classes: int,
+         sample_rate: float,
+         hidden_dim: int = 256,
+     ) -> None:
+         super().__init__()
+         self.num_classes = num_classes
+         self.model = wav2clip_hear.load_model("")
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=16000
+         )
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(512, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = wav2clip_hear.get_scene_embeddings(
+                 x.view(x.shape[0], -1), self.model
+             )
+         return self.proj(embed)
+
+
+ class VGGish(nn.Module):
+     def __init__(
+         self,
+         num_classes: int,
+         sample_rate: float,
+         hidden_dim: int = 256,
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=16000
+         )
+         self.model = hearbaseline.vggish.load_model()
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(128, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = hearbaseline.vggish.get_scene_embeddings(
+                 x.view(x.shape[0], -1), self.model
+             )
+         return self.proj(embed)
+
+
+ class wav2vec2(nn.Module):
+     def __init__(
+         self,
+         num_classes: int,
+         sample_rate: float,
+         hidden_dim: int = 256,
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+         self.resample = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=16000
+         )
+         self.model = hearbaseline.wav2vec2.load_model()
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(1024, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, hidden_dim),
+             torch.nn.ReLU(),
+             torch.nn.Linear(hidden_dim, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor):
+         with torch.no_grad():
+             x = self.resample(x)
+             embed = hearbaseline.wav2vec2.get_scene_embeddings(
+                 x.view(x.shape[0], -1), self.model
+             )
+         return self.proj(embed)
+
  
  # adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
  
@@ -12,20 +136,25 @@ class Cnn14(nn.Module):
          self,
          num_classes: int,
          sample_rate: float,
-         n_fft: int = 2048,
-         hop_length: int = 512,
+         model_sample_rate: float,
+         n_fft: int = 1024,
+         hop_length: int = 256,
          n_mels: int = 128,
+         specaugment: bool = False,
      ):
          super().__init__()
          self.num_classes = num_classes
          self.n_fft = n_fft
          self.hop_length = hop_length
+         self.sample_rate = sample_rate
+         self.model_sample_rate = model_sample_rate
+         self.specaugment = specaugment
  
          window = torch.hann_window(n_fft)
          self.register_buffer("window", window)
  
          self.melspec = torchaudio.transforms.MelSpectrogram(
-             sample_rate,
+             model_sample_rate,
              n_fft,
              hop_length=hop_length,
              n_mels=n_mels,
@@ -45,42 +174,56 @@ class Cnn14(nn.Module):
  
          self.init_weight()
  
+         if sample_rate != model_sample_rate:
+             self.resample = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=model_sample_rate
+             )
+
      def init_weight(self):
          init_bn(self.bn0)
          init_layer(self.fc1)
          init_layer(self.fc_audioset)
  
-     def forward(self, x: torch.Tensor):
+     def forward(self, x: torch.Tensor, train: bool = False):
          """
          Input: (batch_size, data_length)"""
  
+         if self.sample_rate != self.model_sample_rate:
+             x = self.resample(x)
+
          x = self.melspec(x)
+
+         if self.specaugment and train:
+             # import matplotlib.pyplot as plt
+             # fig, axs = plt.subplots(2, 1, sharex=True)
+             # axs[0].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
+             x = self.freq_mask(x)
+             x = self.time_mask(x)
+             # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
+             # plt.savefig("spec_augment.png", dpi=300)
+
          x = x.permute(0, 2, 1, 3)
          x = self.bn0(x)
          x = x.permute(0, 2, 1, 3)
  
-         if self.training:
-             pass
-             # x = self.spec_augmenter(x)
-
          x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
-         x = F.dropout(x, p=0.2, training=self.training)
+         x = F.dropout(x, p=0.2, training=train)
          x = torch.mean(x, dim=3)
  
          (x1, _) = torch.max(x, dim=2)
          x2 = torch.mean(x, dim=2)
          x = x1 + x2
-         x = F.dropout(x, p=0.5, training=self.training)
+         x = F.dropout(x, p=0.5, training=train)
          x = F.relu_(self.fc1(x))
          clipwise_output = self.fc_audioset(x)
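All four new wrappers (PANNs, Wav2CLIP, VGGish, wav2vec2) follow the same pattern: resample the input to the embedding model's rate, compute a frozen HEAR-style scene embedding under torch.no_grad(), and classify it with a small trainable MLP, so only the proj head receives gradients. A minimal usage sketch (assuming hearbaseline and its VGGish weights are available, plus the 48 kHz, 262144-sample chunks from the experiment config; PANNs additionally expects the hard-coded hear2021-panns_hear.pth checkpoint on disk):

import torch
from remfx.classifier import VGGish

clf = VGGish(num_classes=5, sample_rate=48000)
audio = torch.randn(4, 1, 262144)   # (batch, channels, samples), ~5.5 s at 48 kHz
logits = clf(audio)                 # (4, 5) effect logits; the embedding network stays frozen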
 
setup.py CHANGED
@@ -1,8 +1,8 @@
  from pathlib import Path
  from setuptools import setup, find_packages
  
- NAME = "REMFX"
- DESCRIPTION = ""
+ NAME = "remfx"
+ DESCRIPTION = "Universal audio effect removal"
  URL = ""
  EMAIL = "m.rice@se22.qmul.ac.uk"
  AUTHOR = "Matthew Rice"
train_all.sh ADDED
@@ -0,0 +1,6 @@
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_wav2vec2 render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_panns_44k render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_panns_16k render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_panns_pt render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_vggish render_files=False logs_dir=/scratch/cjs-log
+ python scripts/train.py +exp=5-5_cls.yaml model=cls_wav2clip render_files=False logs_dir=/scratch/cjs-log
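
Each line in train_all.sh is an ordinary Hydra run selecting one of the new model configs. The same composition can be reproduced programmatically for debugging; this is only a sketch, and the primary config name ("config") is an assumption since scripts/train.py is not part of this diff:

from hydra import compose, initialize

# build the cls_vggish experiment config without launching training
with initialize(version_base=None, config_path="cfg"):
    cfg = compose(
        config_name="config",  # assumed primary config name
        overrides=["+exp=5-5_cls.yaml", "model=cls_vggish", "render_files=False"],
    )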