mattricesound committed
Commit 9eba2f5 (1 parent: 80a1624)

Fix circular import issue

cfg/exp/chain_inference.yaml CHANGED
@@ -6,7 +6,7 @@ seed: 12345
 sample_rate: 48000
 chunk_size: 262144 # 5.5s
 logs_dir: "./logs"
-render_files: True
+render_files: False
 render_root: "/scratch/EffectSet"
 accelerator: "gpu"
 log_audio: True
@@ -26,10 +26,12 @@ effects_to_remove:
 datamodule:
   batch_size: 16
   num_workers: 8
+  train_dataset: None
+  val_dataset: None
 ckpts:
-  RandomPedalboardChorus: "ckpts/chorus.ckpt"
-  RandomPedalboardDelay: "ckpts/delay.ckpt"
   RandomPedalboardDistortion: "ckpts/distortion.ckpt"
   RandomPedalboardCompressor: "ckpts/compressor.ckpt"
   RandomPedalboardReverb: "ckpts/reverb.ckpt"
+  RandomPedalboardChorus: "ckpts/chorus.ckpt"
+  RandomPedalboardDelay: "ckpts/delay.ckpt"
 num_bins: 1025
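In this config, render_files appears to toggle whether the effected dataset is rendered to disk or reused from render_root, and the nulled-out train_dataset/val_dataset leave only the test split for chain inference. As a minimal sketch of loading such a config by name with Hydra (the config_path/config_name wiring below is an assumption for illustration, not this repo's actual entry point):

import hydra
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path="cfg/exp", config_name="chain_inference")
def main(cfg: DictConfig) -> None:
    # With render_files: False the run is expected to reuse audio
    # already rendered under render_root instead of regenerating it.
    print(cfg.sample_rate, cfg.render_files, cfg.render_root)

if __name__ == "__main__":
    main()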
remfx/callbacks.py CHANGED
@@ -4,7 +4,6 @@ from einops import rearrange
 import torch
 import wandb
 from torch import Tensor
-from remfx.models import RemFXChainInference


 class AudioCallback(Callback):
@@ -47,6 +46,9 @@ class AudioCallback(Callback):
         # Only run on first batch
         if batch_idx == 0 and self.log_audio:
             with torch.no_grad():
+                # Avoids circular import
+                from remfx.models import RemFXChainInference
+
                 if type(pl_module) == RemFXChainInference:
                     y = pl_module.sample(batch)
                 else:
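The fix moves the remfx.models import from module scope into the method body, so remfx.models (which imports remfx.callbacks at the top) no longer triggers a load-time cycle. A self-contained sketch of the deferred-import pattern, using hypothetical modules a_models.py and a_callbacks.py as stand-ins:

# a_models.py
from a_callbacks import log_audio  # safe: a_callbacks no longer imports us at load time

def sample():
    return "audio"

# a_callbacks.py
def log_audio(x):
    print(x)

def on_batch_end():
    # Deferred import: resolved at call time, after both modules have
    # finished initializing, which breaks the circular dependency.
    from a_models import sample
    log_audio(sample())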
remfx/models.py CHANGED
@@ -14,12 +14,13 @@ from remfx.utils import causal_crop
 from remfx.callbacks import log_wandb_audio_batch
 from remfx import effects
 import asteroid
+import random

 ALL_EFFECTS = effects.Pedalboard_Effects


 class RemFXChainInference(pl.LightningModule):
-    def __init__(self, models, sample_rate, num_bins):
+    def __init__(self, models, sample_rate, num_bins, effect_order):
         super().__init__()
         self.model = models
         self.mrstftloss = MultiResolutionSTFTLoss(
@@ -33,36 +34,45 @@ class RemFXChainInference(pl.LightningModule):
                 "FAD": FADLoss(sample_rate=sample_rate),
             }
         )
+        self.sample_rate = sample_rate
+        self.effect_order = effect_order

-    def forward(self, batch):
+    def forward(self, batch, order=None):
         x, y, _, rem_fx_labels = batch
         # Use chain of effects defined in config
+        if order:
+            effects_order = order
+        else:
+            effects_order = self.effect_order
         effects = [
             [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
             for effect_label in rem_fx_labels
         ]
         output = []
         with torch.no_grad():
-            for i, (elem, effect_chain) in enumerate(zip(x, effects)):
+            for i, (elem, effects_list) in enumerate(zip(x, effects)):
                 elem = elem.unsqueeze(0)  # Add batch dim
-                log_wandb_audio_batch(
-                    logger=self.logger,
-                    id=f"{i}_Before",
-                    samples=elem.cpu(),
-                    sampling_rate=self.sample_rate,
-                    caption=effect_chain,
-                )
+                # Indices of the present effects within the effect order
+                effect_chain = [
+                    effects_order.index(effect.__name__) for effect in effects_list
+                ]
+                # log_wandb_audio_batch(
+                #     logger=self.logger,
+                #     id=f"{i}_Before",
+                #     samples=elem.cpu(),
+                #     sampling_rate=self.sample_rate,
+                #     caption=effect_chain,
+                # )
                 for effect in effect_chain:
-                    # Get correct model based on effect name. This is a bit hacky
-                    # Then sample the model
-                    elem = self.model[effect.__name__].model.sample(elem)
-                    log_wandb_audio_batch(
-                        logger=self.logger,
-                        id=f"{i}_{effect}",
-                        samples=elem.cpu(),
-                        sampling_rate=self.sample_rate,
-                        caption=effect_chain,
-                    )
+                    # Look up each effect's model by name and sample it
+                    elem = self.model[effects_order[effect]].model.sample(elem)
+                    # log_wandb_audio_batch(
+                    #     logger=self.logger,
+                    #     id=f"{i}_{effect}",
+                    #     samples=elem.cpu(),
+                    #     sampling_rate=self.sample_rate,
+                    #     caption=effect_chain,
+                    # )
                 output.append(elem.squeeze(0))
             output = torch.stack(output)

@@ -71,8 +81,9 @@ class RemFXChainInference(pl.LightningModule):

     def test_step(self, batch, batch_idx):
         x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
-
-        loss, output = self.forward(batch)
+        # Random order (random.sample returns a new shuffled list)
+        order = random.sample(self.effect_order, k=len(self.effect_order))
+        loss, output = self.forward(batch, order=order)
         # Crop target to match output
         if output.shape[-1] < y.shape[-1]:
             y = causal_crop(y, output.shape[-1])
@@ -86,7 +97,7 @@ class RemFXChainInference(pl.LightningModule):
             else:
                 negate = 1
             self.log(
-                f"test_{metric}",
+                f"test_{metric}_" + "".join(order),
                 negate * self.metrics[metric](output, y),
                 on_step=False,
                 on_epoch=True,
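One stdlib detail behind the test_step ordering: random.shuffle mutates its argument and returns None, whereas random.sample(seq, k=len(seq)) returns a new shuffled list and leaves the original untouched, so effect_order stays intact across test batches. A standalone illustration:

import random

order = ["Distortion", "Compressor", "Reverb"]

shuffled = random.sample(order, k=len(order))           # new shuffled list
assert sorted(shuffled) == sorted(order)
assert order == ["Distortion", "Compressor", "Reverb"]  # original preserved

result = random.shuffle(order)  # shuffles the list in place...
assert result is None           # ...and returns None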
scripts/chain_inference.py CHANGED
@@ -22,7 +22,7 @@ def main(cfg: DictConfig):
         model = hydra.utils.instantiate(cfg.model, _convert_="partial")
         state_dict = torch.load(ckpt_path)["state_dict"]
         model.load_state_dict(state_dict)
-        model.to(cfg.device)
+        model.to("cuda" if torch.cuda.is_available() else "cpu")
         models[effect] = model

     callbacks = []
@@ -48,7 +48,16 @@ def main(cfg: DictConfig):
     )

     inference_model = RemFXChainInference(
-        models, sample_rate=cfg.sample_rate, num_bins=cfg.num_bins
+        models,
+        sample_rate=cfg.sample_rate,
+        num_bins=cfg.num_bins,
+        effect_order=[
+            "RandomPedalboardDistortion",
+            "RandomPedalboardCompressor",
+            "RandomPedalboardReverb",
+            "RandomPedalboardChorus",
+            "RandomPedalboardDelay",
+        ],
     )
     trainer.test(model=inference_model, datamodule=datamodule)
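The model.to(...) change replaces a config-supplied device with PyTorch's standard CUDA-with-CPU-fallback idiom, so the script also runs on machines without a GPU. The same pattern in isolation:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 8).to(device)
x = torch.randn(2, 8, device=device)
print(model(x).device)  # cuda:0 when available, cpu otherwise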