mattricesound committed on
Commit
133e1dc
1 Parent(s): 7173f20

Add effect-order shuffling and use-all-effect-models options for chain_inference to cfg

Browse files
cfg/exp/chain_inference.yaml CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
- num_bins: 1025
 
 
 
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
+ num_bins: 1025
67
+ inference_effects_shuffle: False
68
+ inference_use_all_effect_models: False
cfg/exp/chain_inference_aug.yaml CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
- num_bins: 1025
 
 
 
63
  - "RandomPedalboardReverb"
64
  - "RandomPedalboardChorus"
65
  - "RandomPedalboardDelay"
66
+ num_bins: 1025
67
+ inference_effects_shuffle: False
68
+ inference_use_all_effect_models: False
cfg/exp/chain_inference_aug_classifier.yaml CHANGED
@@ -82,4 +82,6 @@ inference_effects_ordering:
82
  - "RandomPedalboardReverb"
83
  - "RandomPedalboardChorus"
84
  - "RandomPedalboardDelay"
85
- num_bins: 1025
 
 
 
82
  - "RandomPedalboardReverb"
83
  - "RandomPedalboardChorus"
84
  - "RandomPedalboardDelay"
85
+ num_bins: 1025
86
+ inference_effects_shuffle: False
87
+ inference_use_all_effect_models: False
cfg/exp/chain_inference_custom.yaml CHANGED
@@ -68,4 +68,6 @@ inference_effects_ordering:
68
  - "RandomPedalboardReverb"
69
  - "RandomPedalboardChorus"
70
  - "RandomPedalboardDelay"
71
- num_bins: 1025
 
 
 
68
  - "RandomPedalboardReverb"
69
  - "RandomPedalboardChorus"
70
  - "RandomPedalboardDelay"
71
+ num_bins: 1025
72
+ inference_effects_shuffle: False
73
+ inference_use_all_effect_models: False
remfx/models.py CHANGED
@@ -16,12 +16,22 @@ from remfx.callbacks import log_wandb_audio_batch
16
  from einops import rearrange
17
  from remfx import effects
18
  import asteroid
 
19
 
20
  ALL_EFFECTS = effects.Pedalboard_Effects
21
 
22
 
23
  class RemFXChainInference(pl.LightningModule):
24
- def __init__(self, models, sample_rate, num_bins, effect_order, classifier=None):
 
 
 
 
 
 
 
 
 
25
  super().__init__()
26
  self.model = models
27
  self.mrstftloss = MultiResolutionSTFTLoss(
@@ -37,7 +47,9 @@ class RemFXChainInference(pl.LightningModule):
37
  self.sample_rate = sample_rate
38
  self.effect_order = effect_order
39
  self.classifier = classifier
 
40
  self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
 
41
 
42
  def forward(self, batch, batch_idx, order=None):
43
  x, y, _, rem_fx_labels = batch
@@ -46,36 +58,45 @@ class RemFXChainInference(pl.LightningModule):
46
  effects_order = order
47
  else:
48
  effects_order = self.effect_order
49
- old_labels = rem_fx_labels
50
  # Use classifier labels
51
  if self.classifier:
52
  threshold = 0.5
53
  with torch.no_grad():
54
  labels = torch.sigmoid(self.classifier(x))
55
  rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
56
-
57
- effects_present = [
58
- [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
59
- for effect_label in rem_fx_labels
60
- ]
 
 
 
 
 
 
 
 
 
61
  output = []
62
- input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
63
- target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
64
-
65
- log_wandb_audio_batch(
66
- logger=self.logger,
67
- id="input_effected_audio",
68
- samples=input_samples.cpu(),
69
- sampling_rate=self.sample_rate,
70
- caption="Input Data",
71
- )
72
- log_wandb_audio_batch(
73
- logger=self.logger,
74
- id="target_audio",
75
- samples=target_samples.cpu(),
76
- sampling_rate=self.sample_rate,
77
- caption="Target Data",
78
- )
79
  with torch.no_grad():
80
  for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
81
  elem = elem.unsqueeze(0) # Add batch dim
@@ -111,7 +132,6 @@ class RemFXChainInference(pl.LightningModule):
111
  # )
112
  output.append(elem.squeeze(0))
113
  output = torch.stack(output)
114
- output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
115
 
116
  # log_wandb_audio_batch(
117
  # logger=self.logger,
@@ -125,8 +145,9 @@ class RemFXChainInference(pl.LightningModule):
125
 
126
  def test_step(self, batch, batch_idx):
127
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
128
- # Random order
129
- # random.shuffle(self.effect_order)
 
130
  loss, output = self.forward(batch, batch_idx, order=self.effect_order)
131
  # Crop target to match output
132
  if output.shape[-1] < y.shape[-1]:
 
16
  from einops import rearrange
17
  from remfx import effects
18
  import asteroid
19
+ import random
20
 
21
  ALL_EFFECTS = effects.Pedalboard_Effects
22
 
23
 
24
  class RemFXChainInference(pl.LightningModule):
25
+ def __init__(
26
+ self,
27
+ models,
28
+ sample_rate,
29
+ num_bins,
30
+ effect_order,
31
+ classifier=None,
32
+ shuffle_effect_order=False,
33
+ use_all_effect_models=False,
34
+ ):
35
  super().__init__()
36
  self.model = models
37
  self.mrstftloss = MultiResolutionSTFTLoss(
 
47
  self.sample_rate = sample_rate
48
  self.effect_order = effect_order
49
  self.classifier = classifier
50
+ self.shuffle_effect_order = shuffle_effect_order
51
  self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
52
+ self.use_all_effect_models = use_all_effect_models
53
 
54
  def forward(self, batch, batch_idx, order=None):
55
  x, y, _, rem_fx_labels = batch
 
58
  effects_order = order
59
  else:
60
  effects_order = self.effect_order
61
+
62
  # Use classifier labels
63
  if self.classifier:
64
  threshold = 0.5
65
  with torch.no_grad():
66
  labels = torch.sigmoid(self.classifier(x))
67
  rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
68
+ if self.use_all_effect_models:
69
+ effects_present = [
70
+ [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect]
71
+ for effect_label in rem_fx_labels
72
+ ]
73
+ else:
74
+ effects_present = [
75
+ [
76
+ ALL_EFFECTS[i]
77
+ for i, effect in enumerate(effect_label)
78
+ if effect == 1.0
79
+ ]
80
+ for effect_label in rem_fx_labels
81
+ ]
82
  output = []
83
+ # input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
84
+ # target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
85
+
86
+ # log_wandb_audio_batch(
87
+ # logger=self.logger,
88
+ # id="input_effected_audio",
89
+ # samples=input_samples.cpu(),
90
+ # sampling_rate=self.sample_rate,
91
+ # caption="Input Data",
92
+ # )
93
+ # log_wandb_audio_batch(
94
+ # logger=self.logger,
95
+ # id="target_audio",
96
+ # samples=target_samples.cpu(),
97
+ # sampling_rate=self.sample_rate,
98
+ # caption="Target Data",
99
+ # )
100
  with torch.no_grad():
101
  for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
102
  elem = elem.unsqueeze(0) # Add batch dim
 
132
  # )
133
  output.append(elem.squeeze(0))
134
  output = torch.stack(output)
 
135
 
136
  # log_wandb_audio_batch(
137
  # logger=self.logger,
 
145
 
146
  def test_step(self, batch, batch_idx):
147
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
148
+ if self.shuffle_effect_order:
149
+ # Random order
150
+ random.shuffle(self.effect_order)
151
  loss, output = self.forward(batch, batch_idx, order=self.effect_order)
152
  # Crop target to match output
153
  if output.shape[-1] < y.shape[-1]:
scripts/chain_inference.py CHANGED
@@ -65,6 +65,8 @@ def main(cfg: DictConfig):
65
  num_bins=cfg.num_bins,
66
  effect_order=cfg.inference_effects_ordering,
67
  classifier=classifier,
 
 
68
  )
69
  trainer.test(model=inference_model, datamodule=datamodule)
70
 
 
65
  num_bins=cfg.num_bins,
66
  effect_order=cfg.inference_effects_ordering,
67
  classifier=classifier,
68
+ shuffle_effect_order=cfg.inference_effects_shuffle,
69
+ use_all_effect_models=cfg.inference_use_all_effect_models,
70
  )
71
  trainer.test(model=inference_model, datamodule=datamodule)
72