mattricesound committed
Commit 7fc4de1 · Parent: b8427f9

Update datagen silence threshold to 1e-4

cfg/exp/chain_inference_aug_classifier.yaml CHANGED

@@ -47,14 +47,15 @@ classifier:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: False
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft: 1024
-    hop_length: 256
+    n_fft: 2048
+    hop_length: 512
     n_mels: 128
-    sample_rate: 44100
-    model_sample_rate: 44100
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
   specaugment: False
   classifier_ckpt: "ckpts/classifier.ckpt"
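
For orientation (not part of the commit itself): n_fft and hop_length both double, so the classifier's mel front end keeps its 4:1 window-to-hop ratio while doubling frequency resolution, and the previously hard-coded 44100 rates now resolve from the experiment-level ${sample_rate} key. A minimal sketch of the resulting front end, assuming torchaudio and an illustrative resolved rate of 48000 (the real value comes from the experiment config):

```python
import torch
import torchaudio

# Illustrative value only; in the repo this resolves from ${sample_rate}.
sample_rate = 48000

# Mel front end with the updated settings from this commit.
melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=2048,      # was 1024
    hop_length=512,  # was 256; hop/n_fft ratio unchanged at 1/4
    n_mels=128,
)

# One second of noise -> (channels, n_mels, ~94 frames)
spec = melspec(torch.randn(1, sample_rate))
print(spec.shape)
```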
 
remfx/datasets.py CHANGED

@@ -259,7 +259,7 @@ class EffectDataset(Dataset):
         render_files: bool = True,
         render_root: str = None,
         mode: str = "train",
-        parallel: bool = True,
+        parallel: bool = False,
     ):
         super().__init__()
         self.chunks = []
@@ -342,7 +342,6 @@
             chunk = select_random_chunk(
                 random_file_choice, self.chunk_size, self.sample_rate
             )
-
             # Sum to mono
             if chunk.shape[0] > 1:
                 chunk = chunk.sum(0, keepdim=True)
@@ -561,7 +560,7 @@ class EffectDatamodule(pl.LightningDataModule):
     def test_dataloader(self) -> DataLoader:
         return DataLoader(
             dataset=self.test_dataset,
-            batch_size=2,  # Use small, consistent batch size for testing
+            batch_size=1,  # Use small, consistent batch size for testing
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=False,
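
A side effect of dropping the test batch size from 2 to 1: every test step now scores exactly one example, so the per-batch metric rows added to output.csv in remfx/models.py below are effectively per-example rows. A small stand-in illustration (the dataset here is a dummy, not the repo's EffectDataset):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for self.test_dataset: 8 mono clips of 1 s at an assumed 48 kHz.
dataset = TensorDataset(torch.randn(8, 1, 48000))

# Mirrors the updated test_dataloader: one example per batch, no shuffling.
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for (clip,) in loader:
    assert clip.shape == (1, 1, 48000)  # leading batch dim is always 1
```
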
remfx/models.py CHANGED

@@ -37,7 +37,7 @@ class RemFXChainInference(pl.LightningModule):
         self.sample_rate = sample_rate
         self.effect_order = effect_order
         self.classifier = classifier
-        # self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
+        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
 
     def forward(self, batch, batch_idx, order=None):
         x, y, _, rem_fx_labels = batch
@@ -46,7 +46,7 @@ class RemFXChainInference(pl.LightningModule):
             effects_order = order
         else:
             effects_order = self.effect_order
-
+        old_labels = rem_fx_labels
         # Use classifier labels
         if self.classifier:
             threshold = 0.5
@@ -113,13 +113,13 @@ class RemFXChainInference(pl.LightningModule):
         output = torch.stack(output)
         output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
 
-        log_wandb_audio_batch(
-            logger=self.logger,
-            id="output_audio",
-            samples=output_samples.cpu(),
-            sampling_rate=self.sample_rate,
-            caption="Output Data",
-        )
+        # log_wandb_audio_batch(
+        #     logger=self.logger,
+        #     id="output_audio",
+        #     samples=output_samples.cpu(),
+        #     sampling_rate=self.sample_rate,
+        #     caption="Output Data",
+        # )
         loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
         return loss, output
@@ -158,13 +158,16 @@ class RemFXChainInference(pl.LightningModule):
                 prog_bar=True,
                 sync_dist=True,
             )
-            # self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
-            # self.output_str += "\n"
-            # if batch_idx == 4:
-            #     with open("output.csv", "w") as f:
-            #         f.write(self.output_str)
+            print(f"Input_{metric}", negate * self.metrics[metric](x, y))
+            print(f"test_{metric}", negate * self.metrics[metric](output, y))
+            self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
+        self.output_str += "\n"
         return loss
 
+    def on_test_end(self) -> None:
+        with open("output.csv", "w") as f:
+            f.write(self.output_str)
+
     def sample(self, batch):
         return self.forward(batch, 0)[1]
@@ -196,6 +199,7 @@ class RemFX(pl.LightningModule):
         )
         # Log first batch metrics input vs output only once
         self.log_train_audio = True
+        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
 
     @property
     def device(self):
@@ -272,9 +276,16 @@ class RemFX(pl.LightningModule):
                 prog_bar=True,
                 sync_dist=True,
             )
-
+            print(f"Input_{metric}", negate * self.metrics[metric](x, y))
+            print(f"test_{metric}", negate * self.metrics[metric](output, y))
+            self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
+        self.output_str += "\n"
         return loss
 
+    def on_test_end(self) -> None:
+        with open("output.csv", "w") as f:
+            f.write(self.output_str)
+
 
 class OpenUnmixModel(nn.Module):
     def __init__(
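
The models.py changes revive the previously commented-out CSV bookkeeping: output_str accumulates one row of input/output metric values per test batch, and the new on_test_end hook writes the buffer to output.csv once testing finishes, replacing the old ad-hoc dump at batch_idx == 4. A self-contained sketch of that accumulate-then-flush pattern, using a hypothetical helper class rather than the actual LightningModule:

```python
class CSVMetricBuffer:
    """Hypothetical helper mirroring the pattern added in this commit."""

    def __init__(self) -> None:
        # Same header the commit hard-codes in __init__.
        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"

    def add_row(self, values: list[float]) -> None:
        # One CSV row per test batch, formatted like the f-string in test_step.
        self.output_str += ",".join(f"{v:.4f}" for v in values) + "\n"

    def flush(self, path: str = "output.csv") -> None:
        # Mirrors on_test_end: write the whole buffer once, at the end.
        with open(path, "w") as f:
            f.write(self.output_str)


buffer = CSVMetricBuffer()
buffer.add_row([-12.3456, -18.2041, 1.2345, 0.8012])  # dummy metric values
buffer.flush()
```

One detail worth noting in the diff itself: each self.metrics[metric](x, y) call recomputes the metric, so every value is computed twice (once for the print, once for the CSV row); caching the scalar in a local variable would avoid the duplicate work.
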
remfx/utils.py CHANGED

@@ -159,7 +159,7 @@ def select_random_chunk(
     random_start = torch.randint(0, max_len, (1,)).item()
     chunk = audio[:, random_start : random_start + new_chunk_size]
     # Skip if energy too low
-    if torch.mean(torch.abs(chunk)) < 1e-6:
+    if torch.mean(torch.abs(chunk)) < 1e-4:
         return None
     resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
     return resampled_chunk
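
Finally, the change that names the commit: the silence gate in select_random_chunk now rejects chunks whose mean absolute amplitude falls below 1e-4 (roughly -80 dBFS) instead of 1e-6 (roughly -120 dBFS), so near-silent chunks that previously slipped into the rendered dataset are redrawn. A quick sketch of what moves across the boundary (is_too_quiet is a hypothetical stand-in for the inline check):

```python
import torch

def is_too_quiet(chunk: torch.Tensor, threshold: float = 1e-4) -> bool:
    # Same energy test as select_random_chunk, with the new threshold.
    return torch.mean(torch.abs(chunk)).item() < threshold

loud = 0.1 * torch.randn(1, 48000)    # typical audible content
quiet = 1e-5 * torch.randn(1, 48000)  # near-silence, ~8e-6 mean amplitude

print(is_too_quiet(loud))   # False: kept under both old and new thresholds
print(is_too_quiet(quiet))  # True now; the old 1e-6 gate let it through
```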