mattricesound committed
Commit 568c3f1
1 Parent(s): 836d971

Update to latest classifier inference

README.md CHANGED
@@ -10,14 +10,19 @@ git clone https://github.com/mhrice/RemFx.git
 cd RemFx
 git submodule update --init --recursive
 pip install -e . ./umx
+pip install --no-deps hearbaseline
 ```
+Due to incompatibilities between hearbaseline's dependencies (namely numpy/numba) and our other packages, we need to install hearbaseline with no dependencies.
 # Usage
 This repo can be used for many different tasks. Here are some examples.
 ## Run RemFX Detect on a single file
 First, download the checkpoints from [zenodo](https://zenodo.org/record/8179396)
 ```
 scripts/download_checkpoints.sh
-scripts/remfx_detect.sh wet.wav -o dry.wav
+```
+Then run the detect script. This repo contains an example file `example.wav` from our test dataset, which contains 2 effects (chorus and delay) applied to a guitar.
+```
+scripts/remfx_detect.sh example.wav -o dry.wav
 ```
 ## Download the [General Purpose Audio Effect Removal evaluation datasets](https://zenodo.org/record/8187288)
 ```
@@ -69,6 +74,18 @@ If you have generated the dataset separately (see Generate datasets used in the
 
 Also note that the training assumes you have a GPU. To train on CPU, set `accelerator=null` in the config or command-line.
 
+### Logging
+The default logger is the CSV logger.
+To use the WANDB logger:
+
+export WANDB_PROJECT={desired_wandb_project}
+export WANDB_ENTITY={your_wandb_username}
+
+## PANNs pretrained
+```
+wget https://zenodo.org/record/6332525/files/hear2021-panns_hear.pth
+```
+
 ## Evaluate models on the General Purpose Audio Effect Removal evaluation datasets (Table 4 from the paper)
 First download the General Purpose Audio Effect Removal evaluation datasets (see above).
 To use the pretrained RemFX model, download the checkpoints
@@ -148,26 +165,3 @@ Some relevant dataset/training parameters descriptions
 - `distortion`
 - `reverb`
 - `delay`
-
-<!-- # DO WE NEED THIS?
-## Evaluate RemFX with a custom directory
-Assumes directory is structured as
-- root
-    - clean
-        - file1.wav
-        - file2.wav
-        - file3.wav
-    - effected
-        - file1.wav
-        - file2.wav
-        - file3.wav
-
-First set the dataset root:
-```
-export DATASET_ROOT={path/to/datasets}
-```
-
-Then run
-```
-python scripts/chain_inference.py +exp=chain_inference_custom
-``` -->
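Note: since hearbaseline is installed with `--no-deps`, pip performs no dependency resolution for it, so its imports must be satisfied by the versions this repo already pins. A quick sanity check (a sketch; nothing here is part of the repo):

```python
import importlib

# Fails fast if hearbaseline or one of the packages it leans on
# (numpy/numba) is missing or incompatible after the --no-deps install.
for pkg in ("hearbaseline", "numpy", "numba"):
    mod = importlib.import_module(pkg)
    print(pkg, getattr(mod, "__version__", "unknown"))
```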
cfg/exp/5-5_full_cls.yaml CHANGED
@@ -1,11 +1,11 @@
 # @package _global_
 defaults:
-  - override /model: cls_panns_48k
+  - override /model: cls_panns_48k_specaugment
   - override /effects: all
 seed: 12345
 sample_rate: 48000
 chunk_size: 262144 # 5.5s
-logs_dir: "/scratch/cjs-logs"
+logs_dir: "./logs"
 render_files: True
 
 accelerator: "gpu"
cfg/exp/5-5_full_cls_dynamic.yaml CHANGED
@@ -5,7 +5,7 @@ defaults:
 seed: 12345
 sample_rate: 48000
 chunk_size: 262144 # 5.5s
-logs_dir: "/scratch/cjs-logs"
+logs_dir: "./logs"
 render_files: True
 
 accelerator: "gpu"
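Note: both experiment configs now write logs to a portable `./logs` instead of a machine-specific scratch path, and `5-5_full_cls` switches to the SpecAugment classifier model. A sketch of checking the resolved value with Hydra's compose API (the `config_path` and `config_name` values are assumptions about this repo's layout):

```python
from hydra import compose, initialize

# Compose the experiment config the way a training entry point would.
with initialize(version_base=None, config_path="cfg"):
    cfg = compose(config_name="config", overrides=["+exp=5-5_full_cls"])
    print(cfg.logs_dir)  # expected: "./logs"
```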
remfx/classifier.py CHANGED
@@ -171,7 +171,6 @@ class Cnn14(nn.Module):
 
         self.fc1 = nn.Linear(2048, 2048, bias=True)
 
-        # self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
         self.heads = torch.nn.ModuleList()
         for _ in range(num_classes):
             self.heads.append(nn.Linear(2048, 1, bias=True))
@@ -190,7 +189,6 @@ class Cnn14(nn.Module):
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
-        # init_layer(self.fc_audioset)
 
     def forward(self, x: torch.Tensor, train: bool = False):
         """
@@ -202,20 +200,11 @@ class Cnn14(nn.Module):
         x = self.melspec(x)
 
         if self.specaugment and train:
-            # import matplotlib.pyplot as plt
-            # fig, axs = plt.subplots(2, 1, sharex=True)
-            # axs[0].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
             x = self.freq_mask(x)
             x = self.time_mask(x)
-            # axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
-            # plt.savefig("spec_augment.png", dpi=300)
-
-        # x = x.permute(0, 2, 1, 3)
-        # x = self.bn0(x)
-        # x = x.permute(0, 2, 1, 3)
 
         # apply standardization
-        x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
+        x = (x - x.mean(dim=(2, 3), keepdim=True)) / x.std(dim=(2, 3), keepdim=True)
 
         x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
         x = F.dropout(x, p=0.2, training=train)
@@ -241,8 +230,6 @@ class Cnn14(nn.Module):
         for head in self.heads:
            outputs.append(torch.sigmoid(head(x)))
 
-        # clipwise_output = self.fc_audioset(x)
-
         return outputs
 
 
@@ -294,4 +281,4 @@ class ConvBlock(nn.Module):
         else:
             raise Exception("Incorrect argument!")
 
-        return x
\ No newline at end of file
+        return x
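Note: the standardization fix is the substantive change here. With `dim=0`, statistics were computed across the batch, so each example's normalization depended on its batchmates and on batch size; with `dim=(2, 3)`, every mel-spectrogram is standardized over its own frequency and time axes. A self-contained illustration (tensor shapes are assumptions):

```python
import torch

x = torch.randn(8, 1, 64, 256)  # (batch, channels, n_mels, frames), shapes assumed

# Old behavior: statistics shared across the batch dimension.
per_batch = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)

# New behavior: each example normalized over its own freq/time axes.
per_example = (x - x.mean(dim=(2, 3), keepdim=True)) / x.std(dim=(2, 3), keepdim=True)

# Each example is now zero-mean / unit-std independently:
print(per_example[0].mean().item(), per_example[0].std().item())
```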
remfx/datasets.py CHANGED
@@ -666,7 +666,7 @@ class EffectDatamodule(pl.LightningDataModule):
     def test_dataloader(self) -> DataLoader:
         return DataLoader(
             dataset=self.test_dataset,
-            batch_size=1,  # Use small, consistent batch size for testing
+            batch_size=self.test_batch_size,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=False,
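Note: the hardcoded test batch size of 1 becomes the configurable `test_batch_size`, presumably safe now that the classifier standardizes per example (see above) rather than per batch. A sketch of the effect with stand-in data (not the repo's API):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

test_dataset = TensorDataset(torch.randn(64, 2, 48000))  # stand-in audio
test_batch_size = 16  # now configurable instead of fixed at 1

loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
print(len(loader))  # 4 batches rather than 64
```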
remfx/effects.py CHANGED
@@ -84,7 +84,6 @@ def biqaud(
         a2 = 1 - alpha / A
     else:
         pass
-        # raise ValueError(f"Invalid filter_type: {filter_type}.")
 
     b = np.array([b0, b1, b2]) / a0
     a = np.array([a0, a1, a2]) / a0
@@ -291,7 +290,6 @@ class RandomVolumeAutomation(torch.nn.Module):
         gain_db[samples_filled : samples_filled + segment_samples] = fade
         samples_filled = samples_filled + segment_samples
 
-        # print(gain_db)
         x *= 10 ** (gain_db / 20.0)
         return x
 
remfx/models.py CHANGED
@@ -55,12 +55,11 @@ class RemFXChainInference(pl.LightningModule):
             effects_order = order
         else:
             effects_order = self.effect_order
-
         # Use classifier labels
         if self.classifier:
             threshold = 0.5
             with torch.no_grad():
-                labels = torch.sigmoid(self.classifier(x))
+                labels = torch.hstack(self.classifier(x))
                 rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
             if self.use_all_effect_models:
                 effects_present = [
@@ -253,17 +252,8 @@ class RemFX(pl.LightningModule):
             prog_bar=True,
             sync_dist=True,
         )
-        # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
-        # print(f"test_{metric}", negate * self.metrics[metric](output, y))
-        # self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
-        # self.output_str += "\n"
         return loss
 
-    def on_test_end(self) -> None:
-        pass
-        # with open("output.csv", "w") as f:
-        #     f.write(self.output_str)
-
 
 class OpenUnmixModel(nn.Module):
     def __init__(
@@ -418,7 +408,6 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
     else:
         lam = 1
 
-    print(lam)
     if np.random.rand() > 0.5:
         index = torch.randperm(batch_size).to(x.device)
         mixed_x = lam * x + (1 - lam) * x[index, :]
@@ -429,6 +418,7 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
 
     return mixed_x, mixed_y, lam
 
+
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
@@ -533,4 +523,4 @@ class FXClassifier(pl.LightningModule):
             lr=self.lr,
             weight_decay=self.lr_weight_decay,
         )
-        return optimizer
\ No newline at end of file
+        return optimizer
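Note: `torch.sigmoid(...)` becomes `torch.hstack(...)` because `Cnn14.forward` (see classifier.py above) now returns a list of per-effect head outputs that are already sigmoid-activated; chain inference only needs to concatenate them into one `(batch, num_effects)` tensor before thresholding. A minimal sketch with mock head outputs standing in for `self.classifier(x)`:

```python
import torch

batch, num_effects = 4, 5
# One sigmoid-activated column per effect head, as the classifier now returns.
head_outputs = [torch.sigmoid(torch.randn(batch, 1)) for _ in range(num_effects)]

labels = torch.hstack(head_outputs)  # shape: (4, 5)
rem_fx_labels = torch.where(labels > 0.5, 1.0, 0.0)
print(labels.shape, rem_fx_labels)
```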
remfx/tcn.py CHANGED
@@ -91,7 +91,6 @@ class TCN(nn.Module):
         self.causal = causal
         self.estimate_loudness = estimate_loudness
 
-        print(f"Causal: {self.causal}")
         if self.causal:
             self.crop_fn = causal_crop
         else:
scripts/test.py CHANGED
@@ -16,7 +16,8 @@ def main(cfg: DictConfig):
     datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
     log.info(f"Instantiating model <{cfg.model._target_}>.")
     model = hydra.utils.instantiate(cfg.model, _convert_="partial")
-    state_dict = torch.load(cfg.ckpt_path, map_location=torch.device("cpu"))[
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    state_dict = torch.load(cfg.ckpt_path, map_location=device)[
        "state_dict"
     ]
     model.load_state_dict(state_dict)
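Note: checkpoints now load onto CUDA when available instead of always onto CPU. A self-contained sketch of the same pattern (stand-in model and checkpoint, not the Hydra-driven script):

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2)

# Simulate a Lightning-style checkpoint on disk.
torch.save({"state_dict": model.state_dict()}, "ckpt.pt")

# map_location remaps every stored tensor to the chosen device at load
# time, so a GPU-trained checkpoint still loads on a CPU-only machine.
state_dict = torch.load("ckpt.pt", map_location=device)["state_dict"]
model.load_state_dict(state_dict)
```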
setup.py CHANGED
@@ -44,6 +44,15 @@ setup(
         "pyloudnorm",
         "pedalboard",
         "asteroid",
+        "librosa",
+        "speechbrain",
+        "torchcrepe",
+        "torchopenl3",
+        "tensorflow",
+        "transformers",
+        "torchmetrics>=1.0",
+        "wav2clip_hear @ git+https://github.com/hohsiangwu/wav2clip-hear.git",
+        "panns_hear @ git+https://github.com/qiuqiangkong/HEAR2021_Challenge_PANNs",
     ],
     include_package_data=True,
     license="Apache License 2.0",
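Note: the new requirements pull in the HEAR-style embedding models used by the updated classifier inference, two of them installed straight from GitHub. A sketch of verifying the one explicit version floor after install (`packaging` is an assumption, though it is usually present alongside pip):

```python
from importlib.metadata import version

from packaging.version import Version

# torchmetrics >= 1.0 is the only versioned floor added in this commit.
assert Version(version("torchmetrics")) >= Version("1.0"), "torchmetrics too old"
```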