mattricesound commited on
Commit
ace4057
1 Parent(s): e8eaf47

Add custom model choice for chain inference

Browse files
cfg/exp/chain_inference.yaml CHANGED
@@ -26,12 +26,37 @@ datamodule:
26
  batch_size: 16
27
  num_workers: 8
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ckpts:
30
- RandomPedalboardDistortion: "ckpts/demucs_distortion.ckpt"
31
- RandomPedalboardCompressor: "ckpts/demucs_compressor.ckpt"
32
- RandomPedalboardReverb: "ckpts/demucs_reverb.ckpt"
33
- RandomPedalboardChorus: "ckpts/demucs_chorus.ckpt"
34
- RandomPedalboardDelay: "ckpts/demucs_delay.ckpt"
 
 
 
 
 
 
 
 
 
 
35
  inference_effects_ordering:
36
  - "RandomPedalboardDistortion"
37
  - "RandomPedalboardCompressor"
 
26
  batch_size: 16
27
  num_workers: 8
28
 
29
+ dcunet:
30
+ _target_: remfx.models.RemFX
31
+ lr: 1e-4
32
+ lr_beta1: 0.95
33
+ lr_beta2: 0.999
34
+ lr_eps: 1e-6
35
+ lr_weight_decay: 1e-3
36
+ sample_rate: ${sample_rate}
37
+ network:
38
+ _target_: remfx.models.DCUNetModel
39
+ architecture: "Large-DCUNet-20"
40
+ stft_kernel_size: 512
41
+ fix_length_mode: "pad"
42
+ sample_rate: ${sample_rate}
43
+ num_bins: 1025
44
  ckpts:
45
+ RandomPedalboardDistortion:
46
+ model: ${model}
47
+ ckpt_path: "ckpts/demucs_distortion.ckpt"
48
+ RandomPedalboardCompressor:
49
+ model: ${model}
50
+ ckpt_path: "ckpts/demucs_compressor.ckpt"
51
+ RandomPedalboardReverb:
52
+ model: ${dcunet}
53
+ ckpt_path: "ckpts/dcunet_reverb.ckpt"
54
+ RandomPedalboardChorus:
55
+ model: ${dcunet}
56
+ ckpt_path: "ckpts/dcunet_chorus.ckpt"
57
+ RandomPedalboardDelay:
58
+ model: ${dcunet}
59
+ ckpt_path: "ckpts/dcunet_delay.ckpt"
60
  inference_effects_ordering:
61
  - "RandomPedalboardDistortion"
62
  - "RandomPedalboardCompressor"
cfg/exp/chain_inference_aug.yaml CHANGED
@@ -26,12 +26,37 @@ datamodule:
26
  batch_size: 16
27
  num_workers: 8
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ckpts:
30
- RandomPedalboardDistortion: "ckpts/demucs_distortion_aug.ckpt"
31
- RandomPedalboardCompressor: "ckpts/demucs_compressor_aug.ckpt"
32
- RandomPedalboardReverb: "ckpts/demucs_reverb_aug.ckpt"
33
- RandomPedalboardChorus: "ckpts/demucs_chorus_aug.ckpt"
34
- RandomPedalboardDelay: "ckpts/demucs_delay_aug.ckpt"
 
 
 
 
 
 
 
 
 
 
35
  inference_effects_ordering:
36
  - "RandomPedalboardDistortion"
37
  - "RandomPedalboardCompressor"
 
26
  batch_size: 16
27
  num_workers: 8
28
 
29
+ dcunet:
30
+ _target_: remfx.models.RemFX
31
+ lr: 1e-4
32
+ lr_beta1: 0.95
33
+ lr_beta2: 0.999
34
+ lr_eps: 1e-6
35
+ lr_weight_decay: 1e-3
36
+ sample_rate: ${sample_rate}
37
+ network:
38
+ _target_: remfx.models.DCUNetModel
39
+ architecture: "Large-DCUNet-20"
40
+ stft_kernel_size: 512
41
+ fix_length_mode: "pad"
42
+ sample_rate: ${sample_rate}
43
+ num_bins: 1025
44
  ckpts:
45
+ RandomPedalboardDistortion:
46
+ model: ${model}
47
+ ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
48
+ RandomPedalboardCompressor:
49
+ model: ${model}
50
+ ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
51
+ RandomPedalboardReverb:
52
+ model: ${dcunet}
53
+ ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
54
+ RandomPedalboardChorus:
55
+ model: ${dcunet}
56
+ ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
57
+ RandomPedalboardDelay:
58
+ model: ${dcunet}
59
+ ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
60
  inference_effects_ordering:
61
  - "RandomPedalboardDistortion"
62
  - "RandomPedalboardCompressor"
cfg/exp/chain_inference_custom.yaml CHANGED
@@ -31,12 +31,37 @@ datamodule:
31
  _target_: remfx.datasets.InferenceDataset
32
  root: ${oc.env:DATASET_ROOT}
33
  sample_rate: ${sample_rate}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  ckpts:
35
- RandomPedalboardDistortion: "ckpts/demucs_distortion_aug.ckpt"
36
- RandomPedalboardCompressor: "ckpts/demucs_compressor_aug.ckpt"
37
- RandomPedalboardReverb: "ckpts/demucs_reverb_aug.ckpt"
38
- RandomPedalboardChorus: "ckpts/demucs_chorus_aug.ckpt"
39
- RandomPedalboardDelay: "ckpts/demucs_delay_aug.ckpt"
 
 
 
 
 
 
 
 
 
 
40
  inference_effects_ordering:
41
  - "RandomPedalboardDistortion"
42
  - "RandomPedalboardCompressor"
 
31
  _target_: remfx.datasets.InferenceDataset
32
  root: ${oc.env:DATASET_ROOT}
33
  sample_rate: ${sample_rate}
34
+ dcunet:
35
+ _target_: remfx.models.RemFX
36
+ lr: 1e-4
37
+ lr_beta1: 0.95
38
+ lr_beta2: 0.999
39
+ lr_eps: 1e-6
40
+ lr_weight_decay: 1e-3
41
+ sample_rate: ${sample_rate}
42
+ network:
43
+ _target_: remfx.models.DCUNetModel
44
+ architecture: "Large-DCUNet-20"
45
+ stft_kernel_size: 512
46
+ fix_length_mode: "pad"
47
+ sample_rate: ${sample_rate}
48
+ num_bins: 1025
49
  ckpts:
50
+ RandomPedalboardDistortion:
51
+ model: ${model}
52
+ ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
53
+ RandomPedalboardCompressor:
54
+ model: ${model}
55
+ ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
56
+ RandomPedalboardReverb:
57
+ model: ${dcunet}
58
+ ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
59
+ RandomPedalboardChorus:
60
+ model: ${dcunet}
61
+ ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
62
+ RandomPedalboardDelay:
63
+ model: ${dcunet}
64
+ ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
65
  inference_effects_ordering:
66
  - "RandomPedalboardDistortion"
67
  - "RandomPedalboardCompressor"
remfx/callbacks.py CHANGED
@@ -4,6 +4,9 @@ from einops import rearrange
4
  import torch
5
  import wandb
6
  from torch import Tensor
 
 
 
7
 
8
 
9
  class AudioCallback(Callback):
@@ -42,7 +45,7 @@ class AudioCallback(Callback):
42
  def on_validation_batch_start(
43
  self, trainer, pl_module, batch, batch_idx, dataloader_idx
44
  ):
45
- x, target, _, _ = batch
46
  # Only run on first batch
47
  if batch_idx == 0 and self.log_audio:
48
  with torch.no_grad():
@@ -51,6 +54,19 @@ class AudioCallback(Callback):
51
 
52
  if type(pl_module) == RemFXChainInference:
53
  y = pl_module.sample(batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  else:
55
  y = pl_module.model.sample(x)
56
  # Concat samples together for easier viewing in dashboard
 
4
  import torch
5
  import wandb
6
  from torch import Tensor
7
+ from remfx import effects
8
+
9
+ ALL_EFFECTS = effects.Pedalboard_Effects
10
 
11
 
12
  class AudioCallback(Callback):
 
45
  def on_validation_batch_start(
46
  self, trainer, pl_module, batch, batch_idx, dataloader_idx
47
  ):
48
+ x, target, _, rem_fx_labels = batch
49
  # Only run on first batch
50
  if batch_idx == 0 and self.log_audio:
51
  with torch.no_grad():
 
54
 
55
  if type(pl_module) == RemFXChainInference:
56
  y = pl_module.sample(batch)
57
+ effects_present_name = [
58
+ [
59
+ ALL_EFFECTS[i].__name__.replace("RandomPedalboard", "")
60
+ for i, effect in enumerate(effect_label)
61
+ if effect == 1.0
62
+ ]
63
+ for effect_label in rem_fx_labels
64
+ ]
65
+ for i, label in enumerate(effects_present_name):
66
+ self.log(f"{'_'.join(label)}", 0.0)
67
+ # self.log(f"{effects}_{i}", label)
68
+ # trainer.logger.experiment.log(
69
+ # {f"effects_{i}": f"{'_'.join(label)}"}
70
  else:
71
  y = pl_module.model.sample(x)
72
  # Concat samples together for easier viewing in dashboard
scripts/chain_inference.py CHANGED
@@ -18,8 +18,8 @@ def main(cfg: DictConfig):
18
  log.info(f"Instantiating model <{cfg.model._target_}>.")
19
  models = {}
20
  for effect in cfg.ckpts:
21
- ckpt_path = cfg.ckpts[effect]
22
- model = hydra.utils.instantiate(cfg.model, _convert_="partial")
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
25
  model.load_state_dict(state_dict)
 
18
  log.info(f"Instantiating model <{cfg.model._target_}>.")
19
  models = {}
20
  for effect in cfg.ckpts:
21
+ model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
22
+ ckpt_path = cfg.ckpts[effect].ckpt_path
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
25
  model.load_state_dict(state_dict)
scripts/train.py CHANGED
@@ -18,7 +18,16 @@ def main(cfg: DictConfig):
18
 
19
  if "ckpt_path" in cfg:
20
  log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
21
- model = model.load_from_checkpoint(cfg.ckpt_path)
 
 
 
 
 
 
 
 
 
22
 
23
  # Init all callbacks
24
  callbacks = []
 
18
 
19
  if "ckpt_path" in cfg:
20
  log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
21
+ model.load_from_checkpoint(
22
+ cfg.ckpt_path,
23
+ lr=model.lr,
24
+ lr_beta1=model.lr_beta1,
25
+ lr_beta2=model.lr_beta2,
26
+ lr_eps=model.lr_eps,
27
+ lr_weight_decay=model.lr_weight_decay,
28
+ sample_rate=model.sample_rate,
29
+ network=model.model,
30
+ )
31
 
32
  # Init all callbacks
33
  callbacks = []