mattricesound committed
Commit d8f7979 • 1 Parent(s): 12e94ae

Restore classifier, move shell scripts to scripts
README.md CHANGED
@@ -16,12 +16,12 @@ This repo can be used for many different tasks. Here are some examples.
 ## Run RemFX Detect on a single file
 First, need to download the checkpoints from [zenodo](https://zenodo.org/record/8179396)
 ```
-./download_checkpoints.sh
-./remfx_detect.sh wet.wav -o dry.wav
+scripts/download_checkpoints.sh
+scripts/remfx_detect.sh wet.wav -o dry.wav
 ```
 ## Download the [General Purpose Audio Effect Removal evaluation datasets](https://zenodo.org/record/8187288)
 ```
-./download_eval_datasets.sh
+scripts/download_eval_datasets.sh
 ```
 
 ## Download the starter datasets
@@ -73,28 +73,28 @@ Also note that the training assumes you have a GPU. To train on CPU, set `accele
 First download the General Purpose Audio Effect Removal evaluation datasets (see above).
 To use the pretrained RemFX model, download the checkpoints
 ```
-./download_checkpoints.sh
+scripts/download_checkpoints.sh
 ```
 Then run the evaluation script, select the RemFX configuration, between `remfx_oracle`, `remfx_detect`, and `remfx_all`. Then select N, the number of effects to remove.
 ```
-./eval.sh remfx_detect 0-0
-./eval.sh remfx_detect 1-1
-./eval.sh remfx_detect 2-2
-./eval.sh remfx_detect 3-3
-./eval.sh remfx_detect 4-4
-./eval.sh remfx_detect 5-5
+scripts/eval.sh remfx_detect 0-0
+scripts/eval.sh remfx_detect 1-1
+scripts/eval.sh remfx_detect 2-2
+scripts/eval.sh remfx_detect 3-3
+scripts/eval.sh remfx_detect 4-4
+scripts/eval.sh remfx_detect 5-5
 
 ```
 To eval a custom monolithic model, first train a model (see Training)
 Then run the evaluation script, with the config used and checkpoint_path.
 ```
-./eval.sh distortion_aug 0-0 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
+scripts/eval.sh distortion_aug 0-0 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
 ```
 
 To eval a custom effect-specific model as part of the inference chain, first train a model (see Training), then edit `cfg/exp/remfx_{desired_configuration}.yaml -> ckpts -> {effect}`.
 Then run the evaluation script.
 ```
-./eval.sh remfx_detect 0-0
+scripts/eval.sh remfx_detect 0-0
 ```
 
 The script assumes that RemFX_eval_datasets is in the top-level directory.
remfx/classifier.py CHANGED
@@ -1,11 +1,9 @@
 import torch
 import torchaudio
 import torch.nn as nn
-
-# import hearbaseline
-
-# import hearbaseline.vggish
-# import hearbaseline.wav2vec2
+import hearbaseline
+import hearbaseline.vggish
+import hearbaseline.wav2vec2
 
 import wav2clip_hear
 import panns_hear
@@ -173,10 +171,10 @@ class Cnn14(nn.Module):
 
         self.fc1 = nn.Linear(2048, 2048, bias=True)
 
-        self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
-        # self.heads = torch.nn.ModuleList()
-        # for _ in range(num_classes):
-        #     self.heads.append(nn.Linear(2048, 1, bias=True))
+        # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
+        self.heads = torch.nn.ModuleList()
+        for _ in range(num_classes):
+            self.heads.append(nn.Linear(2048, 1, bias=True))
 
         self.init_weight()
 
@@ -192,7 +190,7 @@ class Cnn14(nn.Module):
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
-        init_layer(self.fc_audioset)
+        # init_layer(self.fc_audioset)
 
     def forward(self, x: torch.Tensor, train: bool = False):
         """
@@ -212,12 +210,12 @@ class Cnn14(nn.Module):
         # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
         # plt.savefig("spec_augment.png", dpi=300)
 
-        x = x.permute(0, 2, 1, 3)
-        x = self.bn0(x)
-        x = x.permute(0, 2, 1, 3)
+        # x = x.permute(0, 2, 1, 3)
+        # x = self.bn0(x)
+        # x = x.permute(0, 2, 1, 3)
 
         # apply standardization
-        # x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
+        x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
 
         x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
         x = F.dropout(x, p=0.2, training=train)
@@ -239,13 +237,13 @@ class Cnn14(nn.Module):
         x = F.dropout(x, p=0.5, training=train)
         x = F.relu_(self.fc1(x))
 
-        # outputs = []
-        # for head in self.heads:
-        #     outputs.append(torch.sigmoid(head(x)))
+        outputs = []
+        for head in self.heads:
+            outputs.append(torch.sigmoid(head(x)))
+
+        # clipwise_output = self.fc_audioset(x)
 
-        clipwise_output = self.fc_audioset(x)
-        return clipwise_output
-        # return outputs
+        return outputs
 
 
 class ConvBlock(nn.Module):
@@ -296,4 +294,4 @@ class ConvBlock(nn.Module):
         else:
             raise Exception("Incorrect argument!")
 
-        return x
+        return x
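
For context on the restored head: the diff above swaps Cnn14's single `fc_audioset` projection for one sigmoid head per effect class, which turns the network into a multi-label detector (several effects can be flagged on the same clip). A minimal sketch of that pattern, with illustrative names around the 2048-wide embedding from the diff:

```
import torch
import torch.nn as nn


class MultiHeadSketch(nn.Module):
    # Illustrative stand-in for the restored Cnn14 head, not the repo's API.
    def __init__(self, embed_dim: int = 2048, num_classes: int = 5):
        super().__init__()
        # One independent binary head per effect class, as in the diff above.
        self.heads = nn.ModuleList(
            [nn.Linear(embed_dim, 1, bias=True) for _ in range(num_classes)]
        )

    def forward(self, x: torch.Tensor):
        # Each head emits its own probability, so predictions are not
        # mutually exclusive across classes.
        return [torch.sigmoid(head(x)) for head in self.heads]


# A batch of 4 embeddings yields num_classes tensors of shape (4, 1).
probs = MultiHeadSketch()(torch.randn(4, 2048))
```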
remfx/models.py CHANGED
@@ -143,17 +143,8 @@ class RemFXChainInference(pl.LightningModule):
             prog_bar=True,
             sync_dist=True,
         )
-        # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
-        # print(f"test_{metric}", negate * self.metrics[metric](output, y))
-        # self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
-        # self.output_str += "\n"
         return loss
 
-    def on_test_end(self) -> None:
-        pass
-        # with open("output.csv", "w") as f:
-        #     f.write(self.output_str)
-
     def sample(self, batch):
         return self.forward(batch, 0)[1]
 
@@ -438,7 +429,6 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
 
     return mixed_x, mixed_y, lam
 
-
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
@@ -458,42 +448,7 @@ class FXClassifier(pl.LightningModule):
         self.mixup = mixup
         self.label_smoothing = label_smoothing
 
-        self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
         self.loss_fn = torch.nn.BCELoss()
-
-        if False:
-            self.train_f1 = torchmetrics.classification.MultilabelF1Score(
-                5, average="none", multidim_average="global"
-            )
-            self.val_f1 = torchmetrics.classification.MultilabelF1Score(
-                5, average="none", multidim_average="global"
-            )
-            self.test_f1 = torchmetrics.classification.MultilabelF1Score(
-                5, average="none", multidim_average="global"
-            )
-
-            self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
-                5, threshold=0.5, average="macro", multidim_average="global"
-            )
-            self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
-                5, threshold=0.5, average="macro", multidim_average="global"
-            )
-            self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
-                5, threshold=0.5, average="macro", multidim_average="global"
-            )
-
-            self.metrics = {
-                "train": self.train_acc,
-                "valid": self.val_acc,
-                "test": self.test_acc,
-            }
-
-            self.avg_metrics = {
-                "train": self.train_f1_avg,
-                "valid": self.val_f1_avg,
-                "test": self.test_f1_avg,
-            }
-
         self.metrics = torch.nn.ModuleDict()
         for effect in self.effects:
             self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
@@ -578,4 +533,4 @@ class FXClassifier(pl.LightningModule):
             lr=self.lr,
             weight_decay=self.lr_weight_decay,
         )
-        return optimizer
+        return optimizer
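
The loss change above pairs with those sigmoid heads: `torch.nn.BCELoss` expects inputs already in [0, 1] and an independent 0/1 target per effect, whereas the removed `CrossEntropyLoss` assumes one mutually exclusive class. A minimal illustration of the pairing (the five-effect width matches the surrounding code; the batch size is illustrative):

```
import torch

loss_fn = torch.nn.BCELoss()
pred = torch.sigmoid(torch.randn(4, 5))       # per-effect probabilities
target = torch.randint(0, 2, (4, 5)).float()  # independent 0/1 labels
loss = loss_fn(pred, target)                  # scalar multi-label loss
```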
scripts/chain_inference.py CHANGED
@@ -45,6 +45,7 @@ def main(cfg: DictConfig):
 
     logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
+    cfg.trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )
@@ -68,6 +69,7 @@ def main(cfg: DictConfig):
        shuffle_effect_order=cfg.inference_effects_shuffle,
        use_all_effect_models=cfg.inference_use_all_effect_models,
    )
+
    trainer.test(model=inference_model, datamodule=datamodule)
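
The added accelerator line selects the device at runtime via `torch.cuda.is_available()`, so chain inference falls back to CPU automatically instead of assuming a GPU — the same concern the README raises for training with the `accelerator` setting.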
download_ckpts.sh → scripts/download_ckpts.sh RENAMED
File without changes
download_eval_datasets.sh → scripts/download_eval_datasets.sh RENAMED
File without changes
eval.sh → scripts/eval.sh RENAMED
@@ -1,13 +1,13 @@
 #! /bin/bash
 
 # Example usage:
-# ./eval.sh remfx_detect 0-0
-# ./eval.sh distortion_aug 0-0 -ckpt logs/ckpts/2023-01-21-12-21-44
+# scripts/eval.sh remfx_detect 0-0
+# scripts/eval.sh distortion_aug 0-0 -ckpt logs/ckpts/2023-01-21-12-21-44
 # First 2 arguments are required, third argument is optional
 
 # Default value for the optional parameter
 ckpt_path=""
-
+export DATASET_ROOT=RemFX_eval_datasets
 # Function to display script usage
 function display_usage {
     echo "Usage: $0 <experiment> <dataset> [-ckpt {ckpt_path}]"
remfx_detect.sh → scripts/remfx_detect.sh RENAMED
@@ -1,7 +1,7 @@
 #! /bin/bash
 
 # Example usage:
-# ./remfx_detect.sh wet.wav -o examples/output.wav
+# scripts/remfx_detect.sh wet.wav -o examples/output.wav
 # first argument is required, second argument is optional
 
 # Check if first argument is empty