mattricesound committed on
Commit fb9ce8b
1 Parent(s): 6da1b0d

Fix custom inferencing issues

README.md CHANGED
@@ -47,6 +47,23 @@ see `cfg/exp/default.yaml` for an example.
 - `reverb`
 - `delay`
 
+## Run inference on a directory
+Assumes the directory is structured as
+- root
+    - clean
+        - file1.wav
+        - file2.wav
+        - file3.wav
+    - effected
+        - file1.wav
+        - file2.wav
+        - file3.wav
+
+Set the root path in `shell_vars.sh`, then run `source shell_vars.sh`
+
+`python scripts/chain_inference.py +exp=chain_inference_custom`
+
+
 ## Misc.
 By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
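Note: `InferenceDataset` pairs `clean/` and `effected/` files by index into the two sorted lists (see the `remfx/datasets.py` change below), so both folders should contain identically named files. A minimal sanity-check sketch, with a hypothetical `root` path standing in for the one set in `shell_vars.sh`:

```python
from pathlib import Path

root = Path("/path/to/root")  # hypothetical; use the root set in shell_vars.sh
clean = sorted(p.name for p in (root / "clean").glob("*.wav"))
effected = sorted(p.name for p in (root / "effected").glob("*.wav"))

# The dataset indexes the two sorted lists in parallel, so mismatched or
# missing files would silently pair the wrong clean/effected audio.
assert clean == effected, "clean/ and effected/ filenames do not match"
print(f"Found {len(clean)} matched pairs")
```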
cfg/exp/chain_inference.yaml CHANGED
@@ -28,10 +28,17 @@ datamodule:
   num_workers: 8
   train_dataset: None
   val_dataset: None
+
 ckpts:
   RandomPedalboardDistortion: "ckpts/distortion.ckpt"
   RandomPedalboardCompressor: "ckpts/compressor.ckpt"
   RandomPedalboardReverb: "ckpts/reverb.ckpt"
   RandomPedalboardChorus: "ckpts/chorus.ckpt"
   RandomPedalboardDelay: "ckpts/delay.ckpt"
+inference_effects_ordering:
+  - "RandomPedalboardDistortion"
+  - "RandomPedalboardCompressor"
+  - "RandomPedalboardReverb"
+  - "RandomPedalboardChorus"
+  - "RandomPedalboardDelay"
 num_bins: 1025
cfg/exp/chain_inference_custom.yaml CHANGED
@@ -24,13 +24,13 @@ effects_to_remove:
   - chorus
   - delay
 datamodule:
-  batch_size: 16
+  batch_size: 1
   num_workers: 8
   train_dataset: None
   val_dataset: None
   test_dataset:
     _target_: remfx.datasets.InferenceDataset
-    root: "./data/fx-examples"
+    root: ${oc.env:DATASET_ROOT}
     sample_rate: ${sample_rate}
 ckpts:
   RandomPedalboardDistortion: "ckpts/distortion.ckpt"
@@ -38,4 +38,10 @@ ckpts:
   RandomPedalboardReverb: "ckpts/reverb.ckpt"
   RandomPedalboardChorus: "ckpts/chorus.ckpt"
   RandomPedalboardDelay: "ckpts/delay.ckpt"
+inference_effects_ordering:
+  - "RandomPedalboardDistortion"
+  - "RandomPedalboardCompressor"
+  - "RandomPedalboardReverb"
+  - "RandomPedalboardChorus"
+  - "RandomPedalboardDelay"
 num_bins: 1025
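The `${oc.env:DATASET_ROOT}` interpolation is resolved by OmegaConf from the environment, which is what the `source shell_vars.sh` step in the README provides. A minimal sketch of the resolution, assuming `shell_vars.sh` exports a `DATASET_ROOT` variable:

```python
import os
from omegaconf import OmegaConf

# Stand-in for `source shell_vars.sh`; the variable name comes from the config.
os.environ["DATASET_ROOT"] = "/path/to/root"

cfg = OmegaConf.create({"root": "${oc.env:DATASET_ROOT}"})
print(cfg.root)  # /path/to/root
```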
remfx/datasets.py CHANGED
@@ -361,14 +361,14 @@ class EffectDataset(Dataset):
 
 
 class InferenceDataset(Dataset):
-    def __init__(self, root: str, sample_rate: int):
+    def __init__(self, root: str, sample_rate: int, **kwargs):
         self.root = Path(root)
         self.sample_rate = sample_rate
-        self.clean_paths = list(self.root.glob("clean/*.wav"))
-        self.effected_paths = list(self.root.glob("effected/*.wav"))
+        self.clean_paths = sorted(list(self.root.glob("clean/*.wav")))
+        self.effected_paths = sorted(list(self.root.glob("effected/*.wav")))
 
     def __len__(self) -> int:
-        return len(self.audio_paths)
+        return len(self.clean_paths)
 
     def __getitem__(self, idx: int) -> torch.Tensor:
         clean_path = self.clean_paths[idx]
@@ -379,21 +379,20 @@ class InferenceDataset(Dataset):
         effected = torchaudio.functional.resample(effected_audio, sr, self.sample_rate)
 
         # Sum to mono
-        clean = torch.sum(clean, dim=0)
-        effected = torch.sum(effected, dim=0)
+        clean = torch.sum(clean, dim=0, keepdim=True)
+        effected = torch.sum(effected, dim=0, keepdim=True)
 
         # Pad or trim effected to clean
-        if len(clean) > len(effected):
-            effected = torch.nn.functional.pad(
-                effected, (0, len(clean) - len(effected))
-            )
-        elif len(effected) > len(clean):
-            effected = effected[: len(clean)]
+        if effected.shape[1] > clean.shape[1]:
+            effected = effected[:, : clean.shape[1]]
+        elif effected.shape[1] < clean.shape[1]:
+            pad_size = clean.shape[1] - effected.shape[1]
+            effected = torch.nn.functional.pad(effected, (0, pad_size))
 
         dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
         wet_labels_tensor = torch.ones(len(ALL_EFFECTS))
 
-        return clean, effected, dry_labels_tensor, wet_labels_tensor
+        return effected, clean, dry_labels_tensor, wet_labels_tensor
 
 
 class EffectDatamodule(pl.LightningDataModule):
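The `keepdim=True` fix preserves the channel dimension, so each item keeps the `(C, T)` shape the downstream models expect. A standalone sketch of the new mono-sum and pad-or-trim behavior (tensor sizes are illustrative):

```python
import torch

clean = torch.randn(2, 48000)     # stereo clean recording
effected = torch.randn(2, 47000)  # slightly shorter effected version

# Sum to mono, keeping the channel dim -> shape (1, T)
clean = torch.sum(clean, dim=0, keepdim=True)
effected = torch.sum(effected, dim=0, keepdim=True)

# Pad or trim effected to match clean along the time axis
if effected.shape[1] > clean.shape[1]:
    effected = effected[:, : clean.shape[1]]
elif effected.shape[1] < clean.shape[1]:
    pad_size = clean.shape[1] - effected.shape[1]
    effected = torch.nn.functional.pad(effected, (0, pad_size))

assert clean.shape == effected.shape == (1, 48000)
```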
remfx/models.py CHANGED
@@ -37,7 +37,7 @@ class RemFXChainInference(pl.LightningModule):
         self.sample_rate = sample_rate
         self.effect_order = effect_order
 
-    def forward(self, batch, order=None):
+    def forward(self, batch, batch_idx, order=None):
         x, y, _, rem_fx_labels = batch
         # Use chain of effects defined in config
         if order:
@@ -52,25 +52,30 @@ class RemFXChainInference(pl.LightningModule):
         with torch.no_grad():
             for i, (elem, effects_list) in enumerate(zip(x, effects)):
                 elem = elem.unsqueeze(0)  # Add batch dim
-                effect_chain_idx = [
-                    effects_order.index(effect.__name__) for effect in effects_list
-                ]
+                # effect_chain_idx = [
+                #     effects_order.index(effect.__name__) for effect in effects_list
+                # ]
+                effect_list_names = [effect.__name__ for effect in effects_list]
+                effects = [
+                    effect for effect in effects_order if effect in effect_list_names
+                ]
+
                 # log_wandb_audio_batch(
                 #     logger=self.logger,
-                #     id=f"{i}_Before",
+                #     id=f"{batch_idx}_{i}_Before",
                 #     samples=elem.cpu(),
                 #     sampling_rate=self.sample_rate,
-                #     caption=effect_chain,
+                #     caption=effects,
                 # )
-                for idx in effect_chain_idx:
+                for effect in effects:
                     # Sample the model
-                    elem = self.model[effects_order[idx]].model.sample(elem)
+                    elem = self.model[effect].model.sample(elem)
                     # log_wandb_audio_batch(
                     #     logger=self.logger,
-                    #     id=f"{i}_{effect}",
+                    #     id=f"{batch_idx}_{i}_{effect}",
                     #     samples=elem.cpu(),
                     #     sampling_rate=self.sample_rate,
-                    #     caption=effect_chain,
+                    #     caption=effects,
                     # )
                 output.append(elem.squeeze(0))
             output = torch.stack(output)
@@ -81,8 +86,8 @@ class RemFXChainInference(pl.LightningModule):
     def test_step(self, batch, batch_idx):
         x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
         # Random order
-        random.shuffle(self.effect_order)
-        loss, output = self.forward(batch, order=self.effect_order)
+        # random.shuffle(self.effect_order)
+        loss, output = self.forward(batch, batch_idx, order=self.effect_order)
         # Crop target to match output
         if output.shape[-1] < y.shape[-1]:
             y = causal_crop(y, output.shape[-1])
@@ -96,8 +101,7 @@ class RemFXChainInference(pl.LightningModule):
         else:
             negate = 1
         self.log(
-            f"test_{metric}_"
-            + "".join(self.effect_order).replace("RandomPedalboard", ""),
+            f"test_{metric}",  # + "".join(self.effect_order).replace("RandomPedalboard", ""),
             negate * self.metrics[metric](output, y),
             on_step=False,
             on_epoch=True,
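The rewritten loop filters the configured ordering down to the effects detected in each example, so the per-effect models always run in the canonical `inference_effects_ordering` sequence rather than in label order. A minimal sketch of that filtering:

```python
effects_order = [
    "RandomPedalboardDistortion",
    "RandomPedalboardCompressor",
    "RandomPedalboardReverb",
    "RandomPedalboardChorus",
    "RandomPedalboardDelay",
]
# Effects present in one example, listed in arbitrary order
effect_list_names = ["RandomPedalboardDelay", "RandomPedalboardDistortion"]

# Keep only the detected effects, preserving the configured ordering
effects = [e for e in effects_order if e in effect_list_names]
print(effects)  # ['RandomPedalboardDistortion', 'RandomPedalboardDelay']
```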
scripts/chain_inference.py CHANGED
@@ -47,17 +47,12 @@ def main(cfg: DictConfig):
         logger=logger,
     )
 
+    log.info("Instantiating Inference Model")
     inference_model = RemFXChainInference(
         models,
         sample_rate=cfg.sample_rate,
         num_bins=cfg.num_bins,
-        effect_order=[
-            "RandomPedalboardDistortion",
-            "RandomPedalboardCompressor",
-            "RandomPedalboardReverb",
-            "RandomPedalboardChorus",
-            "RandomPedalboardDelay",
-        ],
+        effect_order=cfg.inference_effects_ordering,
     )
     trainer.test(model=inference_model, datamodule=datamodule)
 