mattricesound committed on
Commit
3e2f073
2 Parent(s): c1cb017 fb9ce8b

Merge pull request #38 from mhrice/custom-inference

Browse files
README.md CHANGED
@@ -47,6 +47,23 @@ see `cfg/exp/default.yaml` for an example.
47
  - `reverb`
48
  - `delay`
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  ## Misc.
51
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
52
 
 
47
  - `reverb`
48
  - `delay`
49
 
50
+ ## Run inference on directory
51
+ Assumes directory is structured as
52
+ - root
53
+   - clean
54
+     - file1.wav
55
+     - file2.wav
56
+     - file3.wav
57
+   - effected
58
+     - file1.wav
59
+     - file2.wav
60
+     - file3.wav
61
+
62
+ Change the root path in `shell_vars.sh`, then run `source shell_vars.sh`
63
+
64
+ `python scripts/chain_inference.py +exp=chain_inference_custom`
65
+
66
+
67
  ## Misc.
68
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
69
 
cfg/exp/chain_inference.yaml CHANGED
@@ -28,10 +28,17 @@ datamodule:
28
  num_workers: 8
29
  train_dataset: None
30
  val_dataset: None
 
31
  ckpts:
32
  RandomPedalboardDistortion: "ckpts/distortion.ckpt"
33
  RandomPedalboardCompressor: "ckpts/compressor.ckpt"
34
  RandomPedalboardReverb: "ckpts/reverb.ckpt"
35
  RandomPedalboardChorus: "ckpts/chorus.ckpt"
36
  RandomPedalboardDelay: "ckpts/delay.ckpt"
 
 
 
 
 
 
37
  num_bins: 1025
 
28
  num_workers: 8
29
  train_dataset: None
30
  val_dataset: None
31
+
32
  ckpts:
33
  RandomPedalboardDistortion: "ckpts/distortion.ckpt"
34
  RandomPedalboardCompressor: "ckpts/compressor.ckpt"
35
  RandomPedalboardReverb: "ckpts/reverb.ckpt"
36
  RandomPedalboardChorus: "ckpts/chorus.ckpt"
37
  RandomPedalboardDelay: "ckpts/delay.ckpt"
38
+ inference_effects_ordering:
39
+ - "RandomPedalboardDistortion"
40
+ - "RandomPedalboardCompressor"
41
+ - "RandomPedalboardReverb"
42
+ - "RandomPedalboardChorus"
43
+ - "RandomPedalboardDelay"
44
  num_bins: 1025
cfg/exp/chain_inference_custom.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ render_files: False
10
+ render_root: "/scratch/EffectSet"
11
+ accelerator: "gpu"
12
+ log_audio: True
13
+ # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [0,5] # [min, max]
16
+ shuffle_kept_effects: True
17
+ shuffle_removed_effects: True
18
+ num_classes: 5
19
+ effects_to_keep:
20
+ effects_to_remove:
21
+ - distortion
22
+ - compressor
23
+ - reverb
24
+ - chorus
25
+ - delay
26
+ datamodule:
27
+ batch_size: 1
28
+ num_workers: 8
29
+ train_dataset: None
30
+ val_dataset: None
31
+ test_dataset:
32
+ _target_: remfx.datasets.InferenceDataset
33
+ root: ${oc.env:DATASET_ROOT}
34
+ sample_rate: ${sample_rate}
35
+ ckpts:
36
+ RandomPedalboardDistortion: "ckpts/distortion.ckpt"
37
+ RandomPedalboardCompressor: "ckpts/compressor.ckpt"
38
+ RandomPedalboardReverb: "ckpts/reverb.ckpt"
39
+ RandomPedalboardChorus: "ckpts/chorus.ckpt"
40
+ RandomPedalboardDelay: "ckpts/delay.ckpt"
41
+ inference_effects_ordering:
42
+ - "RandomPedalboardDistortion"
43
+ - "RandomPedalboardCompressor"
44
+ - "RandomPedalboardReverb"
45
+ - "RandomPedalboardChorus"
46
+ - "RandomPedalboardDelay"
47
+ num_bins: 1025
remfx/datasets.py CHANGED
@@ -360,6 +360,41 @@ class EffectDataset(Dataset):
360
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
361
 
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  class EffectDatamodule(pl.LightningDataModule):
364
  def __init__(
365
  self,
 
360
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
361
 
362
 
363
+ class InferenceDataset(Dataset):
364
+ def __init__(self, root: str, sample_rate: int, **kwargs):
365
+ self.root = Path(root)
366
+ self.sample_rate = sample_rate
367
+ self.clean_paths = sorted(list(self.root.glob("clean/*.wav")))
368
+ self.effected_paths = sorted(list(self.root.glob("effected/*.wav")))
369
+
370
+ def __len__(self) -> int:
371
+ return len(self.clean_paths)
372
+
373
+ def __getitem__(self, idx: int) -> torch.Tensor:
374
+ clean_path = self.clean_paths[idx]
375
+ effected_path = self.effected_paths[idx]
376
+ clean_audio, sr = torchaudio.load(clean_path)
377
+ clean = torchaudio.functional.resample(clean_audio, sr, self.sample_rate)
378
+ effected_audio, sr = torchaudio.load(effected_path)
379
+ effected = torchaudio.functional.resample(effected_audio, sr, self.sample_rate)
380
+
381
+ # Sum to mono
382
+ clean = torch.sum(clean, dim=0, keepdim=True)
383
+ effected = torch.sum(effected, dim=0, keepdim=True)
384
+
385
+ # Pad or trim effected to clean
386
+ if effected.shape[1] > clean.shape[1]:
387
+ effected = effected[:, : clean.shape[1]]
388
+ elif effected.shape[1] < clean.shape[1]:
389
+ pad_size = clean.shape[1] - effected.shape[1]
390
+ effected = torch.nn.functional.pad(effected, (0, pad_size))
391
+
392
+ dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
393
+ wet_labels_tensor = torch.ones(len(ALL_EFFECTS))
394
+
395
+ return effected, clean, dry_labels_tensor, wet_labels_tensor
396
+
397
+
398
  class EffectDatamodule(pl.LightningDataModule):
399
  def __init__(
400
  self,
remfx/models.py CHANGED
@@ -37,7 +37,7 @@ class RemFXChainInference(pl.LightningModule):
37
  self.sample_rate = sample_rate
38
  self.effect_order = effect_order
39
 
40
- def forward(self, batch, order=None):
41
  x, y, _, rem_fx_labels = batch
42
  # Use chain of effects defined in config
43
  if order:
@@ -52,25 +52,30 @@ class RemFXChainInference(pl.LightningModule):
52
  with torch.no_grad():
53
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
54
  elem = elem.unsqueeze(0) # Add batch dim
55
- effect_chain_idx = [
56
- effects_order.index(effect.__name__) for effect in effects_list
 
 
 
 
57
  ]
 
58
  # log_wandb_audio_batch(
59
  # logger=self.logger,
60
- # id=f"{i}_Before",
61
  # samples=elem.cpu(),
62
  # sampling_rate=self.sample_rate,
63
- # caption=effect_chain,
64
  # )
65
- for idx in effect_chain_idx:
66
  # Sample the model
67
- elem = self.model[effects_order[idx]].model.sample(elem)
68
  # log_wandb_audio_batch(
69
  # logger=self.logger,
70
- # id=f"{i}_{effect}",
71
  # samples=elem.cpu(),
72
  # sampling_rate=self.sample_rate,
73
- # caption=effect_chain,
74
  # )
75
  output.append(elem.squeeze(0))
76
  output = torch.stack(output)
@@ -81,8 +86,8 @@ class RemFXChainInference(pl.LightningModule):
81
  def test_step(self, batch, batch_idx):
82
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
83
  # Random order
84
- random.shuffle(self.effect_order)
85
- loss, output = self.forward(batch, order=self.effect_order)
86
  # Crop target to match output
87
  if output.shape[-1] < y.shape[-1]:
88
  y = causal_crop(y, output.shape[-1])
@@ -96,7 +101,7 @@ class RemFXChainInference(pl.LightningModule):
96
  else:
97
  negate = 1
98
  self.log(
99
- f"test_{metric}_" + "".join(self.effect_order),
100
  negate * self.metrics[metric](output, y),
101
  on_step=False,
102
  on_epoch=True,
@@ -307,27 +312,6 @@ class DPTNetModel(nn.Module):
307
  def sample(self, x: Tensor) -> Tensor:
308
  return self.model(x.squeeze(1))
309
 
310
- def __init__(self, sample_rate, num_bins, **kwargs):
311
- super().__init__()
312
- self.model = asteroid.models.DCUNet(**kwargs)
313
- self.mrstftloss = MultiResolutionSTFTLoss(
314
- n_bins=num_bins, sample_rate=sample_rate
315
- )
316
- self.l1loss = nn.L1Loss()
317
-
318
- def forward(self, batch):
319
- x, target = batch
320
- output = self.model(x.squeeze(1)) # B x T
321
- # Crop target to match output
322
- if output.shape[-1] < target.shape[-1]:
323
- target = causal_crop(target, output.shape[-1])
324
- loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
325
- return loss, output
326
-
327
- def sample(self, x: Tensor) -> Tensor:
328
- output = self.model(x.squeeze(1)) # B x T
329
- return output
330
-
331
 
332
  class TCNModel(nn.Module):
333
  def __init__(self, sample_rate, num_bins, **kwargs):
 
37
  self.sample_rate = sample_rate
38
  self.effect_order = effect_order
39
 
40
+ def forward(self, batch, batch_idx, order=None):
41
  x, y, _, rem_fx_labels = batch
42
  # Use chain of effects defined in config
43
  if order:
 
52
  with torch.no_grad():
53
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
54
  elem = elem.unsqueeze(0) # Add batch dim
55
+ # effect_chain_idx = [
56
+ # effects_order.index(effect.__name__) for effect in effects_list
57
+ # ]
58
+ effect_list_names = [effect.__name__ for effect in effects_list]
59
+ effects = [
60
+ effect for effect in effects_order if effect in effect_list_names
61
  ]
62
+
63
  # log_wandb_audio_batch(
64
  # logger=self.logger,
65
+ # id=f"{batch_idx}_{i}_Before",
66
  # samples=elem.cpu(),
67
  # sampling_rate=self.sample_rate,
68
+ # caption=effects,
69
  # )
70
+ for effect in effects:
71
  # Sample the model
72
+ elem = self.model[effect].model.sample(elem)
73
  # log_wandb_audio_batch(
74
  # logger=self.logger,
75
+ # id=f"{batch_idx}_{i}_{effect}",
76
  # samples=elem.cpu(),
77
  # sampling_rate=self.sample_rate,
78
+ # caption=effects,
79
  # )
80
  output.append(elem.squeeze(0))
81
  output = torch.stack(output)
 
86
  def test_step(self, batch, batch_idx):
87
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
88
  # Random order
89
+ # random.shuffle(self.effect_order)
90
+ loss, output = self.forward(batch, batch_idx, order=self.effect_order)
91
  # Crop target to match output
92
  if output.shape[-1] < y.shape[-1]:
93
  y = causal_crop(y, output.shape[-1])
 
101
  else:
102
  negate = 1
103
  self.log(
104
+ f"test_{metric}", # + "".join(self.effect_order).replace("RandomPedalboard", ""),
105
  negate * self.metrics[metric](output, y),
106
  on_step=False,
107
  on_epoch=True,
 
312
  def sample(self, x: Tensor) -> Tensor:
313
  return self.model(x.squeeze(1))
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  class TCNModel(nn.Module):
317
  def __init__(self, sample_rate, num_bins, **kwargs):
scripts/chain_inference.py CHANGED
@@ -47,17 +47,12 @@ def main(cfg: DictConfig):
47
  logger=logger,
48
  )
49
 
 
50
  inference_model = RemFXChainInference(
51
  models,
52
  sample_rate=cfg.sample_rate,
53
  num_bins=cfg.num_bins,
54
- effect_order=[
55
- "RandomPedalboardDistortion",
56
- "RandomPedalboardCompressor",
57
- "RandomPedalboardReverb",
58
- "RandomPedalboardChorus",
59
- "RandomPedalboardDelay",
60
- ],
61
  )
62
  trainer.test(model=inference_model, datamodule=datamodule)
63
 
 
47
  logger=logger,
48
  )
49
 
50
+ log.info("Instantiating Inference Model")
51
  inference_model = RemFXChainInference(
52
  models,
53
  sample_rate=cfg.sample_rate,
54
  num_bins=cfg.num_bins,
55
+ effect_order=cfg.inference_effects_ordering,
 
 
 
 
 
 
56
  )
57
  trainer.test(model=inference_model, datamodule=datamodule)
58