mattricesound committed
Commit 6448f47 · 1 Parent(s): e543fe8

Fix new dataset to work for remfx training

Files changed (3):
  1. README.md +1 -2
  2. remfx/datasets.py +0 -4
  3. remfx/models.py +8 -32
README.md CHANGED
```diff
@@ -9,10 +9,9 @@
 5. `pip install -e umx`
 
 ## Download [VocalSet Dataset](https://zenodo.org/record/1193957)
-1. `wget https://zenodo.org/record/1193957/files/VocalSet.zip?download=1`
+1. `wget https://zenodo.org/record/1442513/files/VocalSet1-2.zip?download=1`
 2. `mv VocalSet.zip?download=1 VocalSet.zip`
 3. `unzip VocalSet.zip`
-4. Manually split singers into train, val, test directories
 
 # Training
 ## Steps
```
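For reference, a hypothetical Python equivalent of the updated download steps (not part of the commit). With the new URL, `wget` would save the archive as `VocalSet1-2.zip?download=1`, while the unchanged `mv` step still references `VocalSet.zip?download=1`; the sketch below sidesteps that by saving straight to the final name.

```python
# Hypothetical sketch of the README's download steps; URL is the one added in
# this commit, everything else is illustrative.
import urllib.request
import zipfile

url = "https://zenodo.org/record/1442513/files/VocalSet1-2.zip?download=1"
archive = "VocalSet.zip"

urllib.request.urlretrieve(url, archive)  # download straight to VocalSet.zip
with zipfile.ZipFile(archive) as zf:
    zf.extractall(".")                    # equivalent of `unzip VocalSet.zip`
```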
remfx/datasets.py CHANGED
```diff
@@ -19,7 +19,6 @@ from remfx.utils import create_sequential_chunks
 # https://zenodo.org/record/1193957 -> VocalSet
 
 ALL_EFFECTS = effects.Pedalboard_Effects
-print(ALL_EFFECTS)
 
 
 singer_splits = {
@@ -206,7 +205,6 @@ class VocalSet(Dataset):
         else:
             num_kept_effects = len(self.effects_to_keep)
         effect_indices = effect_indices[:num_kept_effects]
-        print(effect_indices)
 
         # Index in effect settings
         effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
@@ -249,8 +247,6 @@ class VocalSet(Dataset):
         for label_idx in dry_labels:
             dry_labels_tensor[label_idx] = 1.0
 
-        # effects_present = torch.sum(one_hot, dim=0).float()
-        print(dry_labels_tensor, wet_labels_tensor)
         # Normalize
         normalized_dry = self.normalize(dry)
         normalized_wet = self.normalize(wet)
```
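Beyond the removed debug prints, the tensors assembled here form the four-element item that the models.py changes below rely on. A minimal sketch of that contract, assuming the field order (wet input, dry target, dry labels, wet labels); the diff only confirms that items unpack as four values.

```python
import torch

# Sketch of the per-item contract implied by this commit. The field order is
# an assumption; models.py only shows that batches unpack as `x, y, _, _`.
def make_item(
    normalized_wet: torch.Tensor,
    normalized_dry: torch.Tensor,
    dry_labels_tensor: torch.Tensor,
    wet_labels_tensor: torch.Tensor,
):
    return normalized_wet, normalized_dry, dry_labels_tensor, wet_labels_tensor

# After default collation, a batch is the same 4-tuple with a leading batch dim:
wet = torch.randn(4, 1, 48000)   # effected audio, the model input x
dry = torch.randn(4, 1, 48000)   # clean audio, the target y
labels = torch.zeros(4, 5)       # multi-hot effect labels (5 effects as an example)
x, y, _, _ = make_item(wet, dry, labels.clone(), labels.clone())
```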
remfx/models.py CHANGED
```diff
@@ -94,9 +94,9 @@ class RemFXModel(pl.LightningModule):
         return loss
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
-        loss, output = self.model(batch)
+        x, y, _, _ = batch
+        loss, output = self.model((x, y))
         self.log(f"{mode}_loss", loss)
-        x, y, label = batch
         # Metric logging
         with torch.no_grad():
             for metric in self.metrics:
@@ -123,7 +123,7 @@ class RemFXModel(pl.LightningModule):
     def on_train_batch_start(self, batch, batch_idx):
         # Log initial audio
         if self.log_train_audio:
-            x, y, label = batch
+            x, y, _, _ = batch
             # Concat samples together for easier viewing in dashboard
             input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
             target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
@@ -145,7 +145,7 @@ class RemFXModel(pl.LightningModule):
         self.log_train_audio = False
 
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        x, target, label = batch
+        x, target, _, _ = batch
         # Log Input Metrics
         for metric in self.metrics:
             # SISDR returns negative values, so negate them
@@ -189,7 +189,7 @@ class RemFXModel(pl.LightningModule):
     def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
         self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
         # Log FAD
-        x, target, label = batch
+        x, target, _, _ = batch
         self.log(
             "Input_FAD",
             self.metrics["FAD"](x, target),
@@ -237,7 +237,7 @@ class OpenUnmixModel(torch.nn.Module):
         self.l1loss = torch.nn.L1Loss()
 
     def forward(self, batch):
-        x, target, label = batch
+        x, target = batch
         X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
         Y = self.model(X)
         sep_out = self.separator(x).squeeze(1)
@@ -260,7 +260,7 @@ class DemucsModel(torch.nn.Module):
         self.l1loss = torch.nn.L1Loss()
 
     def forward(self, batch):
-        x, target, label = batch
+        x, target = batch
         output = self.model(x).squeeze(1)
         loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
@@ -275,7 +275,7 @@ class DiffusionGenerationModel(nn.Module):
         self.model = DiffusionModel(in_channels=n_channels)
 
     def forward(self, batch):
-        x, target, label = batch
+        x, target = batch
         sampled_out = self.model.sample(x)
         return self.model(x), sampled_out
@@ -481,30 +481,6 @@ class Cnn14(nn.Module):
         return clipwise_output
 
 
-def spectrogram(
-    x: torch.Tensor,
-    window: torch.Tensor,
-    n_fft: int,
-    hop_length: int,
-    alpha: float,
-) -> torch.Tensor:
-    bs, chs, samp = x.size()
-    x = x.view(bs * chs, -1)  # move channels onto batch dim
-
-    X = torch.stft(
-        x,
-        n_fft=n_fft,
-        hop_length=hop_length,
-        window=window,
-        return_complex=True,
-    )
-
-    # move channels back
-    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
-
-    return torch.pow(X.abs() + 1e-8, alpha)
-
-
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
```
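Taken together, the models.py edits split responsibilities: the Lightning wrapper consumes the dataset's 4-tuple and forwards a label-free `(x, y)` pair, while the inner models' `forward` methods now unpack two values. A minimal runnable sketch of that convention; the inner module here is a stand-in, not the actual Demucs/OpenUnmix code.

```python
import torch
import torch.nn as nn

class InnerModelSketch(nn.Module):
    """Stand-in for DemucsModel/OpenUnmixModel: forward takes an (x, target) pair."""

    def __init__(self):
        super().__init__()
        self.net = nn.Conv1d(1, 1, kernel_size=1)  # placeholder network

    def forward(self, batch):
        x, target = batch                          # 2-tuple, as in the updated forwards
        output = self.net(x)
        loss = nn.functional.l1_loss(output, target)
        return loss, output

def common_step_sketch(model: nn.Module, batch):
    # Mirrors the updated RemFXModel.common_step: drop the label tensors and
    # re-pack only the audio pair for the inner model.
    x, y, _, _ = batch
    loss, output = model((x, y))
    return loss, output

batch = (
    torch.randn(2, 1, 1000),  # x: effected input
    torch.randn(2, 1, 1000),  # y: clean target
    torch.zeros(2, 5),        # dry labels (ignored by the wrapper)
    torch.zeros(2, 5),        # wet labels (ignored by the wrapper)
)
loss, output = common_step_sketch(InnerModelSketch(), batch)
```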