mattricesound committed on
Commit 8125531
1 Parent(s): 7ed6389

Clean project. Add 'all effects' to experiments

README.md CHANGED
@@ -13,7 +13,7 @@
 4. Manually split singers into train, val, test directories
 
 ## Train model
-1. Change Wandb variables in `shell_vars.sh` and `source shell_vars.sh`
+1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
 2. `python scripts/train.py +exp=umx_distortion`
 or
 2. `python scripts/train.py +exp=demucs_distortion`
@@ -33,6 +33,12 @@ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' train
 - `compressor`
 - `distortion`
 - `reverb`
+- `all` (choose random effect to apply to each file)
 
 ## Misc.
-To skip rendering files, add `+datamodule.train_dataset.render_files=False +datamodule.val_dataset.render_files=False` to the command-line
+By default, files are rendered to `input_dir / processed / train/val/test`.
+To skip rendering files (use previously rendered), add `render_files=False` to the command-line
+
+Test
+Experiment dictates data, ckpt dictates model
+`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
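The new Test instructions reference `scripts/test.py`, which is not part of this diff. As a rough, hypothetical sketch of how a Hydra test entry point like that usually wires things together (the experiment config chooses the data and effects, the checkpoint supplies the model weights) — the structure below is an assumption, not the repository's actual script:

```python
# Hypothetical sketch only; the real scripts/test.py is not shown in this commit.
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="../cfg", config_name="config")
def main(cfg: DictConfig) -> None:
    # +exp=... composes the datamodule/effects; +ckpt_path=... points at the trained weights.
    model = hydra.utils.instantiate(cfg.model)
    datamodule = hydra.utils.instantiate(cfg.datamodule)
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)


if __name__ == "__main__":
    main()
```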
cfg/config.yaml CHANGED
@@ -8,6 +8,7 @@ train: True
 sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
+render_files: True
 
 callbacks:
   model_checkpoint:
@@ -26,18 +27,27 @@ datamodule:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    output_root: ${oc.env:OUTPUT_ROOT}/train
     chunk_size_in_sec: 6
     mode: "train"
     effect_types: ${effects.train_effects}
+    render_files: ${render_files}
   val_dataset:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    output_root: ${oc.env:OUTPUT_ROOT}/val
    chunk_size_in_sec: 6
     mode: "val"
     effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+  test_dataset:
+    _target_: remfx.datasets.VocalSet
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    mode: "test"
+    effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+
   batch_size: 16
   num_workers: 8
   pin_memory: True
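The new top-level `render_files` flag is what makes the README's `render_files=False` override work: each dataset block consumes it through `${render_files}` interpolation, so a single command-line value reaches train, val, and test. A minimal OmegaConf sketch (not project code) of that fan-out:

```python
# Standalone OmegaConf sketch of the ${render_files} interpolation pattern.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    render_files: True
    datamodule:
      train_dataset:
        render_files: ${render_files}
      test_dataset:
        render_files: ${render_files}
    """
)
cfg.render_files = False  # e.g. a single `render_files=False` on the command line
print(cfg.datamodule.train_dataset.render_files)  # False
print(cfg.datamodule.test_dataset.render_files)   # False
```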
cfg/config_guitarset.yaml DELETED
@@ -1,52 +0,0 @@
-defaults:
-  - _self_
-  - exp: null
-seed: 12345
-train: True
-sample_rate: 48000
-logs_dir: "./logs"
-log_every_n_steps: 1000
-
-callbacks:
-  model_checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    monitor: "valid_loss" # name of the logged metric which determines when model is improving
-    save_top_k: 1 # save k best models (determined by above metric)
-    save_last: True # additionaly always save model from last epoch
-    mode: "min" # can be "max" or "min"
-    verbose: False
-    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
-    filename: '{epoch:02d}-{valid_loss:.3f}'
-
-datamodule:
-  _target_: remfx.datasets.Datamodule
-  dataset:
-    _target_: remfx.datasets.GuitarSet
-    sample_rate: ${sample_rate}
-    root: ${oc.env:DATASET_ROOT}
-    chunk_size_in_sec: 6
-  val_split: 0.2
-  batch_size: 16
-  num_workers: 8
-  pin_memory: True
-  persistent_workers: True
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: ${oc.env:WANDB_PROJECT}
-  entity: ${oc.env:WANDB_ENTITY}
-  # offline: False # set True to store all logs only locally
-  job_type: "train"
-  group: ""
-  save_dir: "."
-
-trainer:
-  _target_: pytorch_lightning.Trainer
-  precision: 32 # Precision used for tensors, default `32`
-  min_epochs: 0
-  max_epochs: -1
-  enable_model_summary: False
-  log_every_n_steps: 1 # Logs metrics every N batches
-  accumulate_grad_batches: 1
-  accelerator: null
-  devices: 1
cfg/config_guitfx.yaml DELETED
@@ -1,52 +0,0 @@
-defaults:
-  - _self_
-  - exp: null
-seed: 12345
-train: True
-sample_rate: 48000
-logs_dir: "./logs"
-log_every_n_steps: 1000
-
-callbacks:
-  model_checkpoint:
-    _target_: pytorch_lightning.callbacks.ModelCheckpoint
-    monitor: "valid_loss" # name of the logged metric which determines when model is improving
-    save_top_k: 1 # save k best models (determined by above metric)
-    save_last: True # additionaly always save model from last epoch
-    mode: "min" # can be "max" or "min"
-    verbose: False
-    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
-    filename: '{epoch:02d}-{valid_loss:.3f}'
-
-datamodule:
-  _target_: remfx.datasets.Datamodule
-  dataset:
-    _target_: remfx.datasets.GuitarFXDataset
-    sample_rate: ${sample_rate}
-    root: ${oc.env:DATASET_ROOT}
-    chunk_size_in_sec: 6
-  val_split: 0.2
-  batch_size: 16
-  num_workers: 8
-  pin_memory: True
-  persistent_workers: True
-
-logger:
-  _target_: pytorch_lightning.loggers.WandbLogger
-  project: ${oc.env:WANDB_PROJECT}
-  entity: ${oc.env:WANDB_ENTITY}
-  # offline: False # set True to store all logs only locally
-  job_type: "train"
-  group: ""
-  save_dir: "."
-
-trainer:
-  _target_: pytorch_lightning.Trainer
-  precision: 32 # Precision used for tensors, default `32`
-  min_epochs: 0
-  max_epochs: -1
-  enable_model_summary: False
-  log_every_n_steps: 1 # Logs metrics every N batches
-  accumulate_grad_batches: 1
-  accelerator: null
-  devices: 1
cfg/effects/all.yaml ADDED
@@ -0,0 +1,70 @@
+# @package _global_
+effects:
+  train_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: -10
+      max_drive_db: 50
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -42.0
+      max_threshold_db: -20.0
+      min_ratio: 1.5
+      max_ratio: 6.0
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.3
+      max_room_size: 1.0
+      min_damping: 0.2
+      max_damping: 1.0
+      min_wet_dry: 0.2
+      max_wet_dry: 0.8
+      min_width: 0.2
+      max_width: 1.0
+  val_effects:
+    Chorus:
+      _target_: remfx.effects.RandomPedalboardChorus
+      sample_rate: ${sample_rate}
+      min_rate_hz: 1.0
+      max_rate_hz: 1.0
+      min_depth: 0.3
+      max_depth: 0.3
+      min_centre_delay_ms: 7.5
+      max_centre_delay_ms: 7.5
+      min_feedback: 0.4
+      max_feedback: 0.4
+      min_mix: 0.4
+      max_mix: 0.4
+    Distortion:
+      _target_: remfx.effects.RandomPedalboardDistortion
+      sample_rate: ${sample_rate}
+      min_drive_db: 30
+      max_drive_db: 30
+    Compressor:
+      _target_: remfx.effects.RandomPedalboardCompressor
+      sample_rate: ${sample_rate}
+      min_threshold_db: -32
+      max_threshold_db: -32
+      min_ratio: 3.0
+      max_ratio: 3.0
+      min_attack_ms: 10.0
+      max_attack_ms: 10.0
+      min_release_ms: 40.0
+      max_release_ms: 40.0
+    Reverb:
+      _target_: remfx.effects.RandomPedalboardReverb
+      sample_rate: ${sample_rate}
+      min_room_size: 0.5
+      max_room_size: 0.5
+      min_damping: 0.5
+      max_damping: 0.5
+      min_wet_dry: 0.4
+      max_wet_dry: 0.4
+      min_width: 0.5
+      max_width: 0.5
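This config puts every effect into `train_effects`/`val_effects`; the reworked `VocalSet` dataset (further down in this diff) then draws one key at random per chunk, which is what the README's `all` option refers to. A self-contained sketch of that selection step, with plain callables standing in for the instantiated `remfx.effects` modules:

```python
# Stand-in callables; the real dict is built by Hydra from cfg/effects/all.yaml.
import torch

effect_types = {
    "Chorus": lambda x: x,
    "Distortion": lambda x: x,
    "Compressor": lambda x: x,
    "Reverb": lambda x: x,
}

chunk = torch.zeros(1, 48000 * 6)  # (channels, samples): 6 s of audio at 48 kHz
effect_idx = torch.rand(1).item() * len(effect_types.keys())
effect_name = list(effect_types.keys())[int(effect_idx)]
effected_chunk = effect_types[effect_name](chunk)
print(effect_name, effected_chunk.shape)
```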
cfg/exp/demucs_all.yaml ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: all
cfg/exp/umx_all.yaml ADDED
@@ -0,0 +1,4 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: all
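Assuming the training entry point is unchanged, these experiment configs should compose like the existing ones, e.g. `python scripts/train.py +exp=umx_all` or `python scripts/train.py +exp=demucs_all` to train against the full random-effect pool.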
remfx/datasets.py CHANGED
@@ -1,179 +1,16 @@
 import torch
-from torch.utils.data import Dataset, DataLoader, random_split
+from torch.utils.data import Dataset, DataLoader
 import torchaudio
-import torchaudio.transforms as T
 import torch.nn.functional as F
 from pathlib import Path
 import pytorch_lightning as pl
-from typing import Any, List, Tuple
+from typing import Any, List
 from remfx import effects
-from pedalboard import (
-    Pedalboard,
-    Chorus,
-    Reverb,
-    Compressor,
-    Phaser,
-    Delay,
-    Distortion,
-    Limiter,
-)
 from tqdm import tqdm
+from remfx.utils import create_sequential_chunks
 
-# https://zenodo.org/record/7044411/ -> GuitarFX
-# https://zenodo.org/record/3371780 -> GuitarSet
 # https://zenodo.org/record/1193957 -> VocalSet
 
-deterministic_effects = {
-    "Distortion": Pedalboard([Distortion()]),
-    "Compressor": Pedalboard([Compressor()]),
-    "Chorus": Pedalboard([Chorus()]),
-    "Phaser": Pedalboard([Phaser()]),
-    "Delay": Pedalboard([Delay()]),
-    "Reverb": Pedalboard([Reverb()]),
-    "Limiter": Pedalboard([Limiter()]),
-}
-
-
-class GuitarFXDataset(Dataset):
-    def __init__(
-        self,
-        root: str,
-        sample_rate: int,
-        chunk_size_in_sec: int = 3,
-        effect_types: List[str] = None,
-    ):
-        super().__init__()
-        self.wet_files = []
-        self.dry_files = []
-        self.chunks = []
-        self.labels = []
-        self.song_idx = []
-        self.root = Path(root)
-        self.chunk_size_in_sec = chunk_size_in_sec
-        self.sample_rate = sample_rate
-
-        if effect_types is None:
-            effect_types = [
-                d.name for d in self.root.iterdir() if d.is_dir() and d != "Clean"
-            ]
-        current_file = 0
-        for i, effect in enumerate(effect_types):
-            for pickup in Path(self.root / effect).iterdir():
-                wet_files = sorted(list(pickup.glob("*.wav")))
-                dry_files = sorted(
-                    list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
-                )
-                self.wet_files += wet_files
-                self.dry_files += dry_files
-                self.labels += [i] * len(wet_files)
-                for audio_file in wet_files:
-                    chunk_starts, orig_sr = create_sequential_chunks(
-                        audio_file, self.chunk_size_in_sec
-                    )
-                    self.chunks += chunk_starts
-                    self.song_idx += [current_file] * len(chunk_starts)
-                    current_file += 1
-        print(
-            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
-            f"Total chunks: {len(self.chunks)}"
-        )
-        self.resampler = T.Resample(orig_sr, sample_rate)
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load effected and "clean" audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.wet_files[song_idx])
-        y, sr = torchaudio.load(self.dry_files[song_idx])
-        effect_label = self.labels[song_idx]  # Effect label
-
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
-
-        resampled_x = self.resampler(x)
-        resampled_y = self.resampler(y)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-        if resampled_y.shape[-1] < chunk_size_in_samples:
-            resampled_y = F.pad(
-                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
-            )
-        return (resampled_x, resampled_y, effect_label)
-
-
-class GuitarSet(Dataset):
-    def __init__(
-        self,
-        root: str,
-        sample_rate: int,
-        chunk_size_in_sec: int = 3,
-        effect_types: List[torch.nn.Module] = None,
-    ):
-        super().__init__()
-        self.chunks = []
-        self.song_idx = []
-        self.root = Path(root)
-        self.chunk_size_in_sec = chunk_size_in_sec
-        self.files = sorted(list(self.root.glob("./**/*.wav")))
-        self.sample_rate = sample_rate
-        for i, audio_file in enumerate(self.files):
-            chunk_starts, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
-            self.chunks += chunk_starts
-            self.song_idx += [i] * len(chunk_starts)
-        print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
-        self.resampler = T.Resample(orig_sr, sample_rate)
-        self.effect_types = effect_types
-        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
-        self.mode = "train"
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load and effect audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.files[song_idx])
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        resampled_x = self.resampler(x)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-
-        # Add random effect if train
-        if self.mode == "train":
-            random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
-            effect = self.effect_types[effect_name]
-            effected_input = effect(resampled_x)
-        else:
-            # deterministic static effect for eval
-            effect_idx = idx % len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[effect_idx]
-            effect = deterministic_effects[effect_name]
-            effected_input = torch.from_numpy(
-                effect(resampled_x.numpy(), self.sample_rate)
-            )
-        normalized_input = self.normalize(effected_input)
-        normalized_target = self.normalize(resampled_x)
-        return (normalized_input, normalized_target, effect_name)
-
 
 class VocalSet(Dataset):
     def __init__(
@@ -183,7 +20,6 @@ class VocalSet(Dataset):
         chunk_size_in_sec: int = 3,
         effect_types: List[torch.nn.Module] = None,
         render_files: bool = True,
-        output_root: str = "processed",
         mode: str = "train",
     ):
         super().__init__()
@@ -199,14 +35,15 @@ class VocalSet(Dataset):
         self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
         self.effect_types = effect_types
 
-        self.output_root = Path(output_root)
+        self.processed_root = self.root / "processed" / self.mode
 
         self.num_chunks = 0
        print("Total files:", len(self.files))
         print("Processing files...")
         if render_files:
-            self.output_root.mkdir(parents=True, exist_ok=True)
-            for i, audio_file in tqdm(enumerate(self.files)):
+            # Split audio file into chunks, resample, then apply random effects
+            self.processed_root.mkdir(parents=True, exist_ok=True)
+            for audio_file in tqdm(self.files, total=len(self.files)):
                 chunks, orig_sr = create_sequential_chunks(
                     audio_file, self.chunk_size_in_sec
                 )
@@ -220,14 +57,16 @@ class VocalSet(Dataset):
                         resampled_chunk,
                         (0, chunk_size_in_samples - resampled_chunk.shape[1]),
                     )
+                    # Apply effect
                    effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
                    effect_name = list(self.effect_types.keys())[int(effect_idx)]
                    effect = self.effect_types[effect_name]
                    effected_input = effect(resampled_chunk)
+                    # Normalize
                    normalized_input = self.normalize(effected_input)
                    normalized_target = self.normalize(resampled_chunk)
 
-                    output_dir = self.output_root / str(self.num_chunks)
+                    output_dir = self.processed_root / str(self.num_chunks)
                    output_dir.mkdir(exist_ok=True)
                    torchaudio.save(
                        output_dir / "input.wav", normalized_input, self.sample_rate
@@ -235,9 +74,10 @@ class VocalSet(Dataset):
                    torchaudio.save(
                        output_dir / "target.wav", normalized_target, self.sample_rate
                    )
+                    torch.save(effect_name, output_dir / "effect_name.pt")
                    self.num_chunks += 1
         else:
-            self.num_chunks = len(list(self.output_root.glob("./**/*.wav")))
+            self.num_chunks = len(list(self.processed_root.iterdir()))
 
         print(
             f"Found {len(self.files)} {self.mode} files .\n"
@@ -248,95 +88,12 @@ class VocalSet(Dataset):
         return self.num_chunks
 
     def __getitem__(self, idx):
-        # Load audio
-        input_file = self.output_root / str(idx) / "input.wav"
-        target_file = self.output_root / str(idx) / "target.wav"
+        input_file = self.processed_root / str(idx) / "input.wav"
+        target_file = self.processed_root / str(idx) / "target.wav"
+        effect_name = torch.load(self.processed_root / str(idx) / "effect_name.pt")
         input, sr = torchaudio.load(input_file)
         target, sr = torchaudio.load(target_file)
-        return (input, target, "")
-
-
-def create_random_chunks(
-    audio_file: str, chunk_size: int, num_chunks: int
-) -> Tuple[List[Tuple[int, int]], int]:
-    """Create num_chunks random chunks of size chunk_size (seconds)
-    from an audio file.
-    Return sample_index of start of each chunk and original sr
-    """
-    audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size * sr
-    if chunk_size_in_samples >= audio.shape[-1]:
-        chunk_size_in_samples = audio.shape[-1] - 1
-    chunks = []
-    for i in range(num_chunks):
-        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
-        chunks.append(start)
-    return chunks, sr
-
-
-def create_sequential_chunks(
-    audio_file: str, chunk_size: int
-) -> Tuple[List[Tuple[int, int]], int]:
-    """Create sequential chunks of size chunk_size (seconds) from an audio file.
-    Return sample_index of start of each chunk and original sr
-    """
-    chunks = []
-    audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size * sr
-    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
-    for start in chunk_starts:
-        if start + chunk_size_in_samples > audio.shape[-1]:
-            break
-        chunks.append(audio[:, start : start + chunk_size_in_samples])
-    return chunks, sr
-
-
-class Datamodule(pl.LightningDataModule):
-    def __init__(
-        self,
-        dataset,
-        *,
-        val_split: float,
-        batch_size: int,
-        num_workers: int,
-        pin_memory: bool = False,
-        **kwargs: int,
-    ) -> None:
-        super().__init__()
-        self.dataset = dataset
-        self.val_split = val_split
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.data_train: Any = None
-        self.data_val: Any = None
-
-    def setup(self, stage: Any = None) -> None:
-        split = [1.0 - self.val_split, self.val_split]
-        train_size = round(split[0] * len(self.dataset))
-        val_size = round(split[1] * len(self.dataset))
-        self.data_train, self.data_val = random_split(
-            self.dataset, [train_size, val_size]
-        )
-        self.data_val.dataset.mode = "val"
-
-    def train_dataloader(self) -> DataLoader:
-        return DataLoader(
-            dataset=self.data_train,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            shuffle=True,
-        )
-
-    def val_dataloader(self) -> DataLoader:
-        return DataLoader(
-            dataset=self.data_val,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            shuffle=False,
-        )
+        return (input, target, effect_name)
 
 
 class VocalSetDatamodule(pl.LightningDataModule):
@@ -344,6 +101,7 @@ class VocalSetDatamodule(pl.LightningDataModule):
         self,
         train_dataset,
         val_dataset,
+        test_dataset,
         *,
         batch_size: int,
         num_workers: int,
@@ -353,6 +111,7 @@ class VocalSetDatamodule(pl.LightningDataModule):
         super().__init__()
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
+        self.test_dataset = test_dataset
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.pin_memory = pin_memory
@@ -377,3 +136,12 @@ class VocalSetDatamodule(pl.LightningDataModule):
             pin_memory=self.pin_memory,
             shuffle=False,
         )
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.test_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
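For reference, a hedged sketch of how the reworked pieces fit together, mirroring `cfg/config.yaml` (the single-effect dict here is only a placeholder for whatever the effects config instantiates):

```python
# Sketch, not repository code: construct the three splits and the datamodule by hand.
from remfx.datasets import VocalSet, VocalSetDatamodule
from remfx.effects import RandomPedalboardDistortion

effects = {"Distortion": RandomPedalboardDistortion(sample_rate=48000)}  # placeholder dict

datasets = {
    mode: VocalSet(
        root="./data/VocalSet",
        sample_rate=48000,
        chunk_size_in_sec=6,
        effect_types=effects,
        render_files=True,  # False reuses chunks already written under root/processed/<mode>
        mode=mode,
    )
    for mode in ("train", "val", "test")
}

datamodule = VocalSetDatamodule(
    datasets["train"],
    datasets["val"],
    datasets["test"],
    batch_size=16,
    num_workers=8,
    pin_memory=True,
)
```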
remfx/models.py CHANGED
@@ -7,44 +7,12 @@ from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
 from torch.nn import L1Loss
-from frechet_audio_distance import FrechetAudioDistance
-import numpy as np
+from remfx.utils import FADLoss
 
 from umx.openunmix.model import OpenUnmix, Separator
 from torchaudio.models import HDemucs
 
 
-class FADLoss(torch.nn.Module):
-    def __init__(self, sample_rate: float):
-        super().__init__()
-        self.fad = FrechetAudioDistance(
-            use_pca=False, use_activation=False, verbose=False
-        )
-        self.fad.model = self.fad.model.to("cpu")
-        self.sr = sample_rate
-
-    def forward(self, audio_background, audio_eval):
-        embds_background = []
-        embds_eval = []
-        for sample in audio_background:
-            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
-            embds_background.append(embd.cpu().detach().numpy())
-        for sample in audio_eval:
-            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
-            embds_eval.append(embd.cpu().detach().numpy())
-        embds_background = np.concatenate(embds_background, axis=0)
-        embds_eval = np.concatenate(embds_eval, axis=0)
-        mu_background, sigma_background = self.fad.calculate_embd_statistics(
-            embds_background
-        )
-        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
-
-        fad_score = self.fad.calculate_frechet_distance(
-            mu_background, sigma_background, mu_eval, sigma_eval
-        )
-        return fad_score
-
-
 class RemFXModel(pl.LightningModule):
     def __init__(
         self,
@@ -97,6 +65,10 @@ class RemFXModel(pl.LightningModule):
         loss = self.common_step(batch, batch_idx, mode="valid")
         return loss
 
+    def test_step(self, batch, batch_idx):
+        loss = self.common_step(batch, batch_idx, mode="test")
+        return loss
+
     def common_step(self, batch, batch_idx, mode: str = "train"):
         loss, output = self.model(batch)
         self.log(f"{mode}_loss", loss)
@@ -121,6 +93,7 @@ class RemFXModel(pl.LightningModule):
         return loss
 
     def on_train_batch_start(self, batch, batch_idx):
+        # Log initial audio
         if self.log_train_audio:
             x, y, label = batch
             # Concat samples together for easier viewing in dashboard
@@ -143,48 +116,47 @@ class RemFXModel(pl.LightningModule):
             )
             self.log_train_audio = False
 
-    def on_validation_epoch_start(self):
-        self.log_next = True
-
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        if self.log_next:
-            x, target, label = batch
-            # Log Input Metrics
-            for metric in self.metrics:
-                # SISDR returns negative values, so negate them
-                if metric == "SISDR":
-                    negate = -1
-                else:
-                    negate = 1
-                self.log(
-                    f"Input_{metric}",
-                    negate * self.metrics[metric](x, target),
-                    on_step=False,
-                    on_epoch=True,
-                    logger=True,
-                    prog_bar=True,
-                    sync_dist=True,
-                )
-
-            self.model.eval()
-            with torch.no_grad():
-                y = self.model.sample(x)
+        x, target, label = batch
+        # Log Input Metrics
+        for metric in self.metrics:
+            # SISDR returns negative values, so negate them
+            if metric == "SISDR":
+                negate = -1
+            else:
+                negate = 1
+            self.log(
+                f"Input_{metric}",
+                negate * self.metrics[metric](x, target),
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                prog_bar=True,
+                sync_dist=True,
+            )
 
-            # Concat samples together for easier viewing in dashboard
-            # 2 seconds of silence between each sample
-            silence = torch.zeros_like(x)
-            silence = silence[:, : self.sample_rate * 2]
+        self.model.eval()
+        with torch.no_grad():
+            y = self.model.sample(x)
+
+        # Concat samples together for easier viewing in dashboard
+        # 2 seconds of silence between each sample
+        silence = torch.zeros_like(x)
+        silence = silence[:, : self.sample_rate * 2]
 
-            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="prediction_input_target",
-                samples=concat_samples.cpu(),
-                sampling_rate=self.sample_rate,
-                caption=f"Epoch {self.current_epoch}",
-            )
-            self.log_next = False
-            self.model.train()
+        concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
+        log_wandb_audio_batch(
+            logger=self.logger,
+            id="prediction_input_target",
+            samples=concat_samples.cpu(),
+            sampling_rate=self.sample_rate,
+            caption=f"Epoch {self.current_epoch}",
+        )
+        self.log_next = False
+        self.model.train()
+
+    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
+        return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
 
 
 class OpenUnmixModel(torch.nn.Module):
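`models.py` now imports `FADLoss` from `remfx.utils` instead of defining it locally. A minimal, hedged usage sketch of the relocated class, with random tensors as stand-in audio (the underlying `frechet_audio_distance` backend downloads VGGish weights on first use):

```python
# Sketch: exercising the relocated FADLoss with dummy audio batches.
import torch
from remfx.utils import FADLoss

fad = FADLoss(sample_rate=48000)        # project-wide sample rate from cfg/config.yaml
target = torch.randn(4, 1, 48000 * 2)   # (batch, channels, samples): four 2 s mono clips
estimate = torch.randn(4, 1, 48000 * 2)
print(fad(target, estimate))            # scalar Frechet Audio Distance
```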
remfx/utils.py CHANGED
@@ -1,8 +1,12 @@
 import logging
-from typing import List
+from typing import List, Tuple
 import pytorch_lightning as pl
 from omegaconf import DictConfig
 from pytorch_lightning.utilities import rank_zero_only
+from frechet_audio_distance import FrechetAudioDistance
+import numpy as np
+import torch
+import torchaudio
 
 
 def get_logger(name=__name__) -> logging.Logger:
@@ -69,3 +73,69 @@ def log_hyperparameters(
     hparams["callbacks"] = config["callbacks"]
 
     logger.experiment.config.update(hparams)
+
+
+class FADLoss(torch.nn.Module):
+    def __init__(self, sample_rate: float):
+        super().__init__()
+        self.fad = FrechetAudioDistance(
+            use_pca=False, use_activation=False, verbose=False
+        )
+        self.fad.model = self.fad.model.to("cpu")
+        self.sr = sample_rate
+
+    def forward(self, audio_background, audio_eval):
+        embds_background = []
+        embds_eval = []
+        for sample in audio_background:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_background.append(embd.cpu().detach().numpy())
+        for sample in audio_eval:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_eval.append(embd.cpu().detach().numpy())
+        embds_background = np.concatenate(embds_background, axis=0)
+        embds_eval = np.concatenate(embds_eval, axis=0)
+        mu_background, sigma_background = self.fad.calculate_embd_statistics(
+            embds_background
+        )
+        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
+
+        fad_score = self.fad.calculate_frechet_distance(
+            mu_background, sigma_background, mu_eval, sigma_eval
+        )
+        return fad_score
+
+
+def create_random_chunks(
+    audio_file: str, chunk_size: int, num_chunks: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create num_chunks random chunks of size chunk_size (seconds)
+    from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    if chunk_size_in_samples >= audio.shape[-1]:
+        chunk_size_in_samples = audio.shape[-1] - 1
+    chunks = []
+    for i in range(num_chunks):
+        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
+        chunks.append(start)
+    return chunks, sr
+
+
+def create_sequential_chunks(
+    audio_file: str, chunk_size: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create sequential chunks of size chunk_size (seconds) from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    chunks = []
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    for start in chunk_starts:
+        if start + chunk_size_in_samples > audio.shape[-1]:
+            break
+        chunks.append(audio[:, start : start + chunk_size_in_samples])
+    return chunks, sr
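A minimal usage sketch of `create_sequential_chunks` as `VocalSet` consumes it; note that, despite its docstring, the function returns the chunk tensors themselves (at the file's original sample rate), so callers still resample and right-pad the final chunk. The path below is hypothetical:

```python
# Sketch matching the VocalSet preprocessing loop; "some_singer/clip.wav" is a made-up path.
import torch.nn.functional as F
import torchaudio.transforms as T
from remfx.utils import create_sequential_chunks

chunks, orig_sr = create_sequential_chunks("some_singer/clip.wav", chunk_size=6)
resample = T.Resample(orig_sr, 48000)
target_len = 6 * 48000
for chunk in chunks:
    resampled = resample(chunk)
    if resampled.shape[-1] < target_len:
        resampled = F.pad(resampled, (0, target_len - resampled.shape[-1]))
    # resampled is now a fixed-length (channels, 288000) tensor ready for an effect
```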
shell_vars.sh CHANGED
@@ -1,4 +1,3 @@
 export DATASET_ROOT="./data/VocalSet"
-export OUTPUT_ROOT="/scratch/VocalSet/processed"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"
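With `OUTPUT_ROOT` removed, `source shell_vars.sh` now only exports `DATASET_ROOT` (read by `${oc.env:DATASET_ROOT}` in the configs) and the two Wandb variables; rendered chunks are written under `$DATASET_ROOT/processed/<mode>` rather than a separate output tree.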