mattricesound commited on
Commit
7902946
1 Parent(s): 133e1dc

Fix remfx all effects selection

Browse files
Files changed (1) hide show
  1. remfx/models.py +6 -5
remfx/models.py CHANGED
@@ -67,7 +67,7 @@ class RemFXChainInference(pl.LightningModule):
67
  rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
68
  if self.use_all_effect_models:
69
  effects_present = [
70
- [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect]
71
  for effect_label in rem_fx_labels
72
  ]
73
  else:
@@ -79,6 +79,7 @@ class RemFXChainInference(pl.LightningModule):
79
  ]
80
  for effect_label in rem_fx_labels
81
  ]
 
82
  output = []
83
  # input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
84
  # target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
@@ -179,8 +180,8 @@ class RemFXChainInference(pl.LightningModule):
179
  prog_bar=True,
180
  sync_dist=True,
181
  )
182
- print(f"Input_{metric}", negate * self.metrics[metric](x, y))
183
- print(f"test_{metric}", negate * self.metrics[metric](output, y))
184
  self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
185
  self.output_str += "\n"
186
  return loss
@@ -297,8 +298,8 @@ class RemFX(pl.LightningModule):
297
  prog_bar=True,
298
  sync_dist=True,
299
  )
300
- print(f"Input_{metric}", negate * self.metrics[metric](x, y))
301
- print(f"test_{metric}", negate * self.metrics[metric](output, y))
302
  self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
303
  self.output_str += "\n"
304
  return loss
 
67
  rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
68
  if self.use_all_effect_models:
69
  effects_present = [
70
+ [ALL_EFFECTS[i] for i, effect in enumerate(effect_label)]
71
  for effect_label in rem_fx_labels
72
  ]
73
  else:
 
79
  ]
80
  for effect_label in rem_fx_labels
81
  ]
82
+
83
  output = []
84
  # input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
85
  # target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
 
180
  prog_bar=True,
181
  sync_dist=True,
182
  )
183
+ # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
184
+ # print(f"test_{metric}", negate * self.metrics[metric](output, y))
185
  self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
186
  self.output_str += "\n"
187
  return loss
 
298
  prog_bar=True,
299
  sync_dist=True,
300
  )
301
+ # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
302
+ # print(f"test_{metric}", negate * self.metrics[metric](output, y))
303
  self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
304
  self.output_str += "\n"
305
  return loss