mattricesound commited on
Commit
f13cb8e
1 Parent(s): 3e2f073

Fix input metrics

Browse files
cfg/model/demucs.yaml CHANGED
@@ -13,5 +13,5 @@ model:
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
16
- channels: 64
17
 
 
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
16
+ channels: 48
17
 
remfx/callbacks.py CHANGED
@@ -71,48 +71,6 @@ class AudioCallback(Callback):
71
  self.on_validation_batch_start(*args)
72
 
73
 
74
- class MetricCallback(Callback):
75
- def on_validation_batch_start(
76
- self, trainer, pl_module, batch, batch_idx, dataloader_idx
77
- ):
78
- x, target, _, _ = batch
79
- # Log Input Metrics
80
- for metric in pl_module.metrics:
81
- # SISDR returns negative values, so negate them
82
- if metric == "SISDR":
83
- negate = -1
84
- else:
85
- negate = 1
86
- # Only Log FAD on test set
87
- if metric == "FAD":
88
- continue
89
- pl_module.log(
90
- f"Input_{metric}",
91
- negate * pl_module.metrics[metric](x, target),
92
- on_step=False,
93
- on_epoch=True,
94
- logger=True,
95
- prog_bar=True,
96
- sync_dist=True,
97
- )
98
-
99
- def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
100
- self.on_validation_batch_start(
101
- trainer, pl_module, batch, batch_idx, dataloader_idx
102
- )
103
- # Log FAD
104
- x, target, _, _ = batch
105
- pl_module.log(
106
- "Input_FAD",
107
- pl_module.metrics["FAD"](x, target),
108
- on_step=False,
109
- on_epoch=True,
110
- logger=True,
111
- prog_bar=True,
112
- sync_dist=True,
113
- )
114
-
115
-
116
  def log_wandb_audio_batch(
117
  logger: pl.loggers.WandbLogger,
118
  id: str,
 
71
  self.on_validation_batch_start(*args)
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def log_wandb_audio_batch(
75
  logger: pl.loggers.WandbLogger,
76
  id: str,
remfx/models.py CHANGED
@@ -14,7 +14,6 @@ from remfx.utils import causal_crop
14
  from remfx.callbacks import log_wandb_audio_batch
15
  from remfx import effects
16
  import asteroid
17
- import random
18
 
19
  ALL_EFFECTS = effects.Pedalboard_Effects
20
 
@@ -52,31 +51,36 @@ class RemFXChainInference(pl.LightningModule):
52
  with torch.no_grad():
53
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
54
  elem = elem.unsqueeze(0) # Add batch dim
55
- # effect_chain_idx = [
56
- # effects_order.index(effect.__name__) for effect in effects_list
57
- # ]
58
  effect_list_names = [effect.__name__ for effect in effects_list]
59
  effects = [
60
  effect for effect in effects_order if effect in effect_list_names
61
  ]
62
 
63
- # log_wandb_audio_batch(
64
- # logger=self.logger,
65
- # id=f"{batch_idx}_{i}_Before",
66
- # samples=elem.cpu(),
67
- # sampling_rate=self.sample_rate,
68
- # caption=effects,
69
- # )
70
  for effect in effects:
71
  # Sample the model
72
  elem = self.model[effect].model.sample(elem)
73
- # log_wandb_audio_batch(
74
- # logger=self.logger,
75
- # id=f"{batch_idx}_{i}_{effect}",
76
- # samples=elem.cpu(),
77
- # sampling_rate=self.sample_rate,
78
- # caption=effects,
79
- # )
 
 
 
 
 
 
 
80
  output.append(elem.squeeze(0))
81
  output = torch.stack(output)
82
 
@@ -111,7 +115,7 @@ class RemFXChainInference(pl.LightningModule):
111
  )
112
 
113
  def sample(self, batch):
114
- return self.forward(batch)[1]
115
 
116
 
117
  class RemFX(pl.LightningModule):
@@ -207,6 +211,15 @@ class RemFX(pl.LightningModule):
207
  prog_bar=True,
208
  sync_dist=True,
209
  )
 
 
 
 
 
 
 
 
 
210
  return loss
211
 
212
 
 
14
  from remfx.callbacks import log_wandb_audio_batch
15
  from remfx import effects
16
  import asteroid
 
17
 
18
  ALL_EFFECTS = effects.Pedalboard_Effects
19
 
 
51
  with torch.no_grad():
52
  for i, (elem, effects_list) in enumerate(zip(x, effects)):
53
  elem = elem.unsqueeze(0) # Add batch dim
54
+ # Get the correct effect by search for names in effects_order
 
 
55
  effect_list_names = [effect.__name__ for effect in effects_list]
56
  effects = [
57
  effect for effect in effects_order if effect in effect_list_names
58
  ]
59
 
60
+ log_wandb_audio_batch(
61
+ logger=self.logger,
62
+ id=f"{i}_Before",
63
+ samples=elem.cpu(),
64
+ sampling_rate=self.sample_rate,
65
+ caption=effects,
66
+ )
67
  for effect in effects:
68
  # Sample the model
69
  elem = self.model[effect].model.sample(elem)
70
+ log_wandb_audio_batch(
71
+ logger=self.logger,
72
+ id=f"{i}_{effect}",
73
+ samples=elem.cpu(),
74
+ sampling_rate=self.sample_rate,
75
+ caption=effects,
76
+ )
77
+ log_wandb_audio_batch(
78
+ logger=self.logger,
79
+ id=f"{i}_After",
80
+ samples=elem.cpu(),
81
+ sampling_rate=self.sample_rate,
82
+ caption=effects,
83
+ )
84
  output.append(elem.squeeze(0))
85
  output = torch.stack(output)
86
 
 
115
  )
116
 
117
  def sample(self, batch):
118
+ return self.forward(batch, 0)[1]
119
 
120
 
121
  class RemFX(pl.LightningModule):
 
211
  prog_bar=True,
212
  sync_dist=True,
213
  )
214
+ self.log(
215
+ f"Input_{metric}",
216
+ negate * self.metrics[metric](x, y),
217
+ on_step=False,
218
+ on_epoch=True,
219
+ logger=True,
220
+ prog_bar=True,
221
+ sync_dist=True,
222
+ )
223
  return loss
224
 
225
 
scripts/chain_inference.py CHANGED
@@ -20,9 +20,10 @@ def main(cfg: DictConfig):
20
  for effect in cfg.ckpts:
21
  ckpt_path = cfg.ckpts[effect]
22
  model = hydra.utils.instantiate(cfg.model, _convert_="partial")
23
- state_dict = torch.load(ckpt_path)["state_dict"]
 
24
  model.load_state_dict(state_dict)
25
- model.to("cuda" if torch.cuda.is_available() else "cpu")
26
  models[effect] = model
27
 
28
  callbacks = []
 
20
  for effect in cfg.ckpts:
21
  ckpt_path = cfg.ckpts[effect]
22
  model = hydra.utils.instantiate(cfg.model, _convert_="partial")
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
25
  model.load_state_dict(state_dict)
26
+ model.to(device)
27
  models[effect] = model
28
 
29
  callbacks = []