mattricesound committed
Commit e4fc05d · 2 parents: 35ce5ba 6448f47

Merge pull request #29 from mhrice/cjs--classifier-v2

README.md CHANGED
@@ -9,10 +9,9 @@
 5. `pip install -e umx`
 
 ## Download [VocalSet Dataset](https://zenodo.org/record/1193957)
-1. `wget https://zenodo.org/record/1193957/files/VocalSet.zip?download=1`
+1. `wget https://zenodo.org/record/1442513/files/VocalSet1-2.zip?download=1`
 2. `mv VocalSet.zip?download=1 VocalSet.zip`
 3. `unzip VocalSet.zip`
-4. Manually split singers into train, val, test directories
 
 # Training
 ## Steps
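
The manual "split singers into train, val, test directories" step is dropped because the new dataset code (remfx/datasets.py below) assigns singers to splits itself. A minimal sketch for sanity-checking the extracted archive, assuming the zip is unpacked so that a `data_by_singer/` folder sits under the dataset root (the path and singer names mirror the `singer_splits` dict added below); `DATASET_ROOT` is a hypothetical placeholder:

```python
# Sketch: check that the unzipped VocalSet has the singer folders the new
# dataset code expects under <root>/data_by_singer/.
from pathlib import Path

DATASET_ROOT = Path("./data/VocalSet")  # hypothetical path; adjust to your setup
expected = [f"male{i}" for i in range(1, 12)] + [f"female{i}" for i in range(1, 10)]

singer_root = DATASET_ROOT / "data_by_singer"
missing = [name for name in expected if not (singer_root / name).is_dir()]
print("missing singer folders:", missing or "none")
```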
cfg/config.yaml CHANGED
@@ -16,6 +16,7 @@ max_kept_effects: -1
 max_removed_effects: -1
 shuffle_kept_effects: True
 shuffle_removed_effects: False
+num_classes: 4
 effects_to_use:
   - compressor
   - distortion
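
The new `num_classes` key is read by other config files through Hydra/OmegaConf interpolation (`${num_classes}` in cfg/model/classifier.yaml below). A minimal sketch of how that resolution works, using OmegaConf directly with illustrative values:

```python
# Sketch: ${num_classes} in a child config resolves against this top-level key.
from omegaconf import OmegaConf

root = OmegaConf.create({"num_classes": 4})
classifier = OmegaConf.create({"network": {"num_classes": "${num_classes}"}})
merged = OmegaConf.merge(root, classifier)
print(merged.network.num_classes)  # -> 4
```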
cfg/model/classifier.yaml ADDED
@@ -0,0 +1,14 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 1e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 4096
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+
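
With `# @package _global_`, this file overrides the top-level `model` group when selected. A minimal sketch of how such a config can be resolved into objects with `hydra.utils.instantiate`; the literal `sample_rate` and `num_classes` values below are stand-ins for whatever the main config interpolates in:

```python
# Sketch: build the classifier from a config dict outside the training entrypoint.
# 48000 and 4 are illustrative stand-ins for ${sample_rate} and ${num_classes}.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "remfx.models.FXClassifier",
        "lr": 1e-4,
        "lr_weight_decay": 1e-3,
        "sample_rate": 48000,
        "network": {
            "_target_": "remfx.models.Cnn14",
            "num_classes": 4,
            "n_fft": 4096,
            "hop_length": 512,
            "n_mels": 128,
            "sample_rate": 48000,
        },
    }
)
model = instantiate(cfg)  # FXClassifier wrapping a Cnn14, built recursively
```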
remfx/datasets.py CHANGED
@@ -1,16 +1,19 @@
+import os
+import sys
+import glob
 import torch
-from torch.utils.data import Dataset, DataLoader
-import torch.nn.functional as F
+import shutil
 import torchaudio
-from pathlib import Path
 import pytorch_lightning as pl
-import sys
-from typing import Any, List, Dict
-from remfx import effects
+import torch.nn.functional as F
+
 from tqdm import tqdm
-from remfx.utils import create_sequential_chunks
-import shutil
+from pathlib import Path
+from remfx import effects
 from ordered_set import OrderedSet
+from typing import Any, List, Dict
+from torch.utils.data import Dataset, DataLoader
+from remfx.utils import create_sequential_chunks
 
 
 # https://zenodo.org/record/1193957 -> VocalSet
@@ -18,6 +21,30 @@ from ordered_set import OrderedSet
 ALL_EFFECTS = effects.Pedalboard_Effects
 
 
+singer_splits = {
+    "train": [
+        "male1",
+        "male2",
+        "male3",
+        "male4",
+        "male5",
+        "male6",
+        "male7",
+        "male8",
+        "male9",
+        "female1",
+        "female2",
+        "female3",
+        "female4",
+        "female5",
+        "female6",
+        "female7",
+    ],
+    "val": ["male10", "female8"],
+    "test": ["male11", "female9"],
+}
+
+
 class VocalSet(Dataset):
     def __init__(
         self,
@@ -43,8 +70,6 @@ class VocalSet(Dataset):
         self.chunk_size = chunk_size
         self.sample_rate = sample_rate
         self.mode = mode
-        mode_path = self.root / self.mode
-        self.files = sorted(list(mode_path.glob("./**/*.wav")))
         self.max_kept_effects = max_kept_effects
         self.max_removed_effects = max_removed_effects
         self.effects_to_use = effects_to_use
@@ -53,11 +78,20 @@ class VocalSet(Dataset):
         self.effects = effect_modules
         self.shuffle_kept_effects = shuffle_kept_effects
         self.shuffle_removed_effects = shuffle_removed_effects
-
         effects_string = "_".join(self.effects_to_use + ["_"] + self.effects_to_remove)
         self.effects_to_keep = self.validate_effect_input()
         self.proc_root = self.render_root / "processed" / effects_string / self.mode
 
+        # find all singer directories
+        singer_dirs = glob.glob(os.path.join(self.root, "data_by_singer", "*"))
+        singer_dirs = [
+            sd for sd in singer_dirs if os.path.basename(sd) in singer_splits[mode]
+        ]
+        self.files = []
+        for singer_dir in singer_dirs:
+            self.files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
+        self.files = sorted(self.files)
+
         if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
             print("Found processed files.")
             if render_files:
@@ -86,12 +120,15 @@ class VocalSet(Dataset):
                         # Skip if chunk is too small
                         continue
 
-                    dry, wet, effect = self.process_effects(resampled_chunk)
+                    dry, wet, dry_effects, wet_effects = self.process_effects(
+                        resampled_chunk
+                    )
                     output_dir = self.proc_root / str(self.num_chunks)
                     output_dir.mkdir(exist_ok=True)
                     torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
                     torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
-                    torch.save(effect, output_dir / "effect.pt")
+                    torch.save(dry_effects, output_dir / "dry_effects.pt")
+                    torch.save(wet_effects, output_dir / "wet_effects.pt")
                     self.num_chunks += 1
         else:
             self.num_chunks = len(list(self.proc_root.iterdir()))
@@ -107,10 +144,11 @@ class VocalSet(Dataset):
     def __getitem__(self, idx):
         input_file = self.proc_root / str(idx) / "input.wav"
         target_file = self.proc_root / str(idx) / "target.wav"
-        effect_name = torch.load(self.proc_root / str(idx) / "effect.pt")
+        dry_effect_names = torch.load(self.proc_root / str(idx) / "dry_effects.pt")
+        wet_effect_names = torch.load(self.proc_root / str(idx) / "wet_effects.pt")
         input, sr = torchaudio.load(input_file)
         target, sr = torchaudio.load(target_file)
-        return (input, target, effect_name)
+        return (input, target, dry_effect_names, wet_effect_names)
 
     def validate_effect_input(self):
         for effect in self.effects.values():
@@ -154,27 +192,28 @@ class VocalSet(Dataset):
         return kept_fx
 
     def process_effects(self, dry: torch.Tensor):
-        labels = []
-
         # Apply Kept Effects
         # Shuffle effects if specified
         if self.shuffle_kept_effects:
            effect_indices = torch.randperm(len(self.effects_to_keep))
         else:
            effect_indices = torch.arange(len(self.effects_to_keep))
+
         # Up to max_kept_effects
         if self.max_kept_effects != -1:
            num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects)) + 1
         else:
            num_kept_effects = len(self.effects_to_keep)
         effect_indices = effect_indices[:num_kept_effects]
+
         # Index in effect settings
         effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
         # Apply
+        dry_labels = []
         for effect in effects_to_apply:
             dry = effect(dry)
-            labels.append(ALL_EFFECTS.index(type(effect)))
+            dry_labels.append(ALL_EFFECTS.index(type(effect)))
 
         # Apply effects_to_remove
         # Shuffle effects if specified
@@ -185,9 +224,7 @@ class VocalSet(Dataset):
            effect_indices = torch.arange(len(self.effects_to_remove))
         # Up to max_removed_effects
         if self.max_removed_effects != -1:
-            num_kept_effects = (
-                int(torch.rand(1).item() * (self.max_removed_effects)) + 1
-            )
+            num_kept_effects = int(torch.rand(1).item() * (self.max_removed_effects))
         else:
            num_kept_effects = len(self.effects_to_remove)
         effect_indices = effect_indices[: self.max_removed_effects]
@@ -195,17 +232,25 @@ class VocalSet(Dataset):
         effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
         # Apply
+
+        wet_labels = []
         for effect in effects_to_apply:
             wet = effect(wet)
-            labels.append(ALL_EFFECTS.index(type(effect)))
+            wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+        for label_idx in wet_labels:
+            wet_labels_tensor[label_idx] = 1.0
+
+        for label_idx in dry_labels:
+            dry_labels_tensor[label_idx] = 1.0
 
-        # Convert labels to one-hot
-        one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
-        effects_present = torch.sum(one_hot, dim=0).float()
         # Normalize
         normalized_dry = self.normalize(dry)
         normalized_wet = self.normalize(wet)
-        return normalized_dry, normalized_wet, effects_present
+        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
 
 
 class VocalSetDatamodule(pl.LightningDataModule):
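
Each rendered chunk now stores two label tensors alongside the audio, and `__getitem__` returns a 4-tuple `(input, target, dry_effect_names, wet_effect_names)`, where both label tensors are multi-hot vectors indexed by position in `ALL_EFFECTS`. A small sketch of that encoding on its own, assuming the `remfx` package from this repo is installed; the chosen effect classes are just examples:

```python
# Sketch: the multi-hot label encoding used in process_effects. Each position in
# the vector corresponds to an entry of effects.Pedalboard_Effects (ALL_EFFECTS).
import torch
from remfx import effects

ALL_EFFECTS = effects.Pedalboard_Effects
applied = [effects.RandomPedalboardDistortion, effects.RandomPedalboardCompressor]

wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
for effect_cls in applied:
    wet_labels_tensor[ALL_EFFECTS.index(effect_cls)] = 1.0
print(wet_labels_tensor)  # 1.0 at the distortion and compressor positions, 0.0 elsewhere
```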
remfx/effects.py CHANGED
@@ -701,9 +701,9 @@ class RandomAudioEffectsChannel(torch.nn.Module):
 Pedalboard_Effects = [
     RandomPedalboardReverb,
     RandomPedalboardChorus,
-    RandomPedalboardDelay,
+    # RandomPedalboardDelay,
     RandomPedalboardDistortion,
     RandomPedalboardCompressor,
-    RandomPedalboardPhaser,
-    RandomPedalboardLimiter,
+    # RandomPedalboardPhaser,
+    # RandomPedalboardLimiter,
 ]
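
Commenting out delay, phaser, and limiter leaves four active effects, which is what `num_classes: 4` in cfg/config.yaml and the multi-hot label length in remfx/datasets.py both rely on. A quick check, assuming the package is installed:

```python
# Sketch: the active effect list length should match num_classes in cfg/config.yaml.
from remfx import effects

print(len(effects.Pedalboard_Effects))  # expected: 4
print([cls.__name__ for cls in effects.Pedalboard_Effects])
```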
remfx/models.py CHANGED
@@ -1,15 +1,19 @@
+import wandb
 import torch
-from torch import Tensor, nn
+import torchaudio
+import torchmetrics
 import pytorch_lightning as pl
+import torch.nn.functional as F
+
+from torch import Tensor, nn
 from einops import rearrange
-import wandb
+from torchaudio.models import HDemucs
 from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss
-from remfx.utils import FADLoss
-
 from umx.openunmix.model import OpenUnmix, Separator
-from torchaudio.models import HDemucs
+
+from remfx.utils import FADLoss
 
 
 class RemFXModel(pl.LightningModule):
@@ -90,9 +94,9 @@ class RemFXModel(pl.LightningModule):
         return loss
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
-        loss, output = self.model(batch)
+        x, y, _, _ = batch
+        loss, output = self.model((x, y))
         self.log(f"{mode}_loss", loss)
-        x, y, label = batch
         # Metric logging
         with torch.no_grad():
             for metric in self.metrics:
@@ -119,7 +123,7 @@ class RemFXModel(pl.LightningModule):
     def on_train_batch_start(self, batch, batch_idx):
         # Log initial audio
         if self.log_train_audio:
-            x, y, label = batch
+            x, y, _, _ = batch
             # Concat samples together for easier viewing in dashboard
             input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
             target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
@@ -141,7 +145,7 @@ class RemFXModel(pl.LightningModule):
             self.log_train_audio = False
 
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        x, target, label = batch
+        x, target, _, _ = batch
         # Log Input Metrics
         for metric in self.metrics:
             # SISDR returns negative values, so negate them
@@ -185,7 +189,7 @@ class RemFXModel(pl.LightningModule):
     def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
         self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
         # Log FAD
-        x, target, label = batch
+        x, target, _, _ = batch
         self.log(
             "Input_FAD",
             self.metrics["FAD"](x, target),
@@ -233,7 +237,7 @@ class OpenUnmixModel(torch.nn.Module):
         self.l1loss = torch.nn.L1Loss()
 
     def forward(self, batch):
-        x, target, label = batch
+        x, target = batch
         X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
         Y = self.model(X)
         sep_out = self.separator(x).squeeze(1)
@@ -256,7 +260,7 @@ class DemucsModel(torch.nn.Module):
         self.l1loss = torch.nn.L1Loss()
 
     def forward(self, batch):
-        x, target, label = batch
+        x, target = batch
         output = self.model(x).squeeze(1)
         loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
@@ -271,7 +275,7 @@ class DiffusionGenerationModel(nn.Module):
         self.model = DiffusionModel(in_channels=n_channels)
 
     def forward(self, batch):
-        x, target, label = batch
+        x, target = batch
         sampled_out = self.model.sample(x)
         return self.model(x), sampled_out
 
@@ -326,3 +330,215 @@ def spectrogram(
     X = X.view(bs, chs, X.shape[-2], X.shape[-1])
 
     return torch.pow(X.abs() + 1e-8, alpha)
+
+
+# adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
+
+
+def init_layer(layer):
+    """Initialize a Linear or Convolutional layer."""
+    nn.init.xavier_uniform_(layer.weight)
+
+    if hasattr(layer, "bias"):
+        if layer.bias is not None:
+            layer.bias.data.fill_(0.0)
+
+
+def init_bn(bn):
+    """Initialize a Batchnorm layer."""
+    bn.bias.data.fill_(0.0)
+    bn.weight.data.fill_(1.0)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_layer(self.conv1)
+        init_layer(self.conv2)
+        init_bn(self.bn1)
+        init_bn(self.bn2)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        x = F.relu_(self.bn2(self.conv2(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise Exception("Incorrect argument!")
+
+        return x
+
+
+class Cnn14(nn.Module):
+    def __init__(
+        self,
+        num_classes: int,
+        sample_rate: float,
+        n_fft: int = 2048,
+        hop_length: int = 512,
+        n_mels: int = 128,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+
+        window = torch.hann_window(n_fft)
+        self.register_buffer("window", window)
+
+        self.melspec = torchaudio.transforms.MelSpectrogram(
+            sample_rate,
+            n_fft,
+            hop_length=hop_length,
+            n_mels=n_mels,
+        )
+
+        self.bn0 = nn.BatchNorm2d(n_mels)
+
+        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+        self.fc1 = nn.Linear(2048, 2048, bias=True)
+        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_bn(self.bn0)
+        init_layer(self.fc1)
+        init_layer(self.fc_audioset)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Input: (batch_size, data_length)"""
+
+        x = self.melspec(x)
+        x = x.permute(0, 2, 1, 3)
+        x = self.bn0(x)
+        x = x.permute(0, 2, 1, 3)
+
+        if self.training:
+            pass
+            # x = self.spec_augmenter(x)
+
+        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        clipwise_output = self.fc_audioset(x)
+
+        return clipwise_output
+
+
+class FXClassifier(pl.LightningModule):
+    def __init__(
+        self,
+        lr: float,
+        lr_weight_decay: float,
+        sample_rate: float,
+        network: nn.Module,
+    ):
+        super().__init__()
+        self.lr = lr
+        self.lr_weight_decay = lr_weight_decay
+        self.sample_rate = sample_rate
+        self.network = network
+
+    def forward(self, x: torch.Tensor):
+        return self.network(x)
+
+    def common_step(self, batch, batch_idx, mode: str = "train"):
+        x, y, dry_label, wet_label = batch
+        pred_label = self.network(x)
+        loss = torch.nn.functional.cross_entropy(pred_label, dry_label)
+        self.log(
+            f"{mode}_loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        self.log(
+            f"{mode}_mAP",
+            torchmetrics.functional.retrieval_average_precision(
+                pred_label, dry_label.long()
+            ),
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="train")
+
+    def validation_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="valid")
+
+    def test_step(self, batch, batch_idx):
+        return self.common_step(batch, batch_idx, mode="test")
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            self.network.parameters(),
+            lr=self.lr,
+            weight_decay=self.lr_weight_decay,
+        )
+        return optimizer
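
A minimal sketch of exercising the new classifier outside of Lightning: build a `Cnn14`, wrap it in `FXClassifier`, and run a random batch through the same cross-entropy objective that `common_step` uses. The sample rate, audio length, and label values are illustrative only, and the sketch assumes the `remfx` package from this repo is installed:

```python
# Sketch: shape-check Cnn14/FXClassifier on random audio (illustrative values only).
import torch
import torch.nn.functional as F
from remfx.models import Cnn14, FXClassifier

sample_rate = 48000  # stand-in for the project-level ${sample_rate}
net = Cnn14(num_classes=4, sample_rate=sample_rate, n_fft=4096, hop_length=512, n_mels=128)
model = FXClassifier(lr=1e-4, lr_weight_decay=1e-3, sample_rate=sample_rate, network=net)

x = torch.randn(2, 1, 2 * sample_rate)     # (batch, channels, samples) of fake audio
dry_label = torch.zeros(2, 4)              # multi-hot targets, as built in datasets.py
dry_label[:, 1] = 1.0

logits = model(x)                          # (batch, num_classes)
loss = F.cross_entropy(logits, dry_label)  # same objective as FXClassifier.common_step
print(logits.shape, loss.item())

optimizer = model.configure_optimizers()   # AdamW over the wrapped network
```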