mattricesound commited on
Commit
c7866f1
β€’
1 Parent(s): 848b108

Refactor dataset to apply multiple effects at a time

Browse files
README.md CHANGED
@@ -40,6 +40,6 @@ Experiment dictates data, ckpt dictates model
40
  `python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
41
 
42
  ## Misc.
43
- By default, files are rendered to `input_dir / processed / train/val/test`.
44
  To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
45
  To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
 
40
  `python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
41
 
42
  ## Misc.
43
+ By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
44
  To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
45
  To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
cfg/{effects β†’ applied_effects}/all.yaml RENAMED
@@ -1,5 +1,5 @@
1
  # @package _global_
2
- effects:
3
  Chorus:
4
  _target_: remfx.effects.RandomPedalboardChorus
5
  sample_rate: ${sample_rate}
 
1
  # @package _global_
2
+ applied_effects:
3
  Chorus:
4
  _target_: remfx.effects.RandomPedalboardChorus
5
  sample_rate: ${sample_rate}
cfg/config.yaml CHANGED
@@ -1,15 +1,17 @@
1
  defaults:
2
  - _self_
3
  - model: null
4
- - effects: null
 
5
 
 
6
  seed: 12345
7
  train: True
8
  sample_rate: 48000
9
  chunk_size: 262144 # 5.5s
10
  logs_dir: "./logs"
11
  render_files: True
12
- render_root: "./data/processed"
13
 
14
  callbacks:
15
  model_checkpoint:
@@ -33,7 +35,9 @@ datamodule:
33
  root: ${oc.env:DATASET_ROOT}
34
  chunk_size: ${chunk_size}
35
  mode: "train"
36
- effect_types: ${effects}
 
 
37
  render_files: ${render_files}
38
  render_root: ${render_root}
39
  val_dataset:
@@ -42,7 +46,9 @@ datamodule:
42
  root: ${oc.env:DATASET_ROOT}
43
  chunk_size: ${chunk_size}
44
  mode: "val"
45
- effect_types: ${effects}
 
 
46
  render_files: ${render_files}
47
  render_root: ${render_root}
48
  test_dataset:
@@ -51,7 +57,9 @@ datamodule:
51
  root: ${oc.env:DATASET_ROOT}
52
  chunk_size: ${chunk_size}
53
  mode: "test"
54
- effect_types: ${effects}
 
 
55
  render_files: ${render_files}
56
  render_root: ${render_root}
57
 
 
1
  defaults:
2
  - _self_
3
  - model: null
4
+ - applied_effects: null
5
+ - effect_to_remove: null
6
 
7
+ max_effects_per_file: 3
8
  seed: 12345
9
  train: True
10
  sample_rate: 48000
11
  chunk_size: 262144 # 5.5s
12
  logs_dir: "./logs"
13
  render_files: True
14
+ render_root: "./data"
15
 
16
  callbacks:
17
  model_checkpoint:
 
35
  root: ${oc.env:DATASET_ROOT}
36
  chunk_size: ${chunk_size}
37
  mode: "train"
38
+ applied_effects: ${applied_effects}
39
+ effect_to_remove: ${effect_to_remove}
40
+ max_effects_per_file: ${max_effects_per_file}
41
  render_files: ${render_files}
42
  render_root: ${render_root}
43
  val_dataset:
 
46
  root: ${oc.env:DATASET_ROOT}
47
  chunk_size: ${chunk_size}
48
  mode: "val"
49
+ applied_effects: ${applied_effects}
50
+ effect_to_remove: ${effect_to_remove}
51
+ max_effects_per_file: ${max_effects_per_file}
52
  render_files: ${render_files}
53
  render_root: ${render_root}
54
  test_dataset:
 
57
  root: ${oc.env:DATASET_ROOT}
58
  chunk_size: ${chunk_size}
59
  mode: "test"
60
+ applied_effects: ${applied_effects}
61
+ effect_to_remove: ${effect_to_remove}
62
+ max_effects_per_file: ${max_effects_per_file}
63
  render_files: ${render_files}
64
  render_root: ${render_root}
65
 
cfg/effect_to_remove/all.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ effect_to_remove:
3
+ Chorus:
4
+ _target_: remfx.effects.RandomPedalboardChorus
5
+ sample_rate: ${sample_rate}
6
+ min_depth: 0.2
7
+ min_mix: 0.3
8
+ Distortion:
9
+ _target_: remfx.effects.RandomPedalboardDistortion
10
+ sample_rate: ${sample_rate}
11
+ min_drive_db: 10
12
+ max_drive_db: 50
13
+ Compressor:
14
+ _target_: remfx.effects.RandomPedalboardCompressor
15
+ sample_rate: ${sample_rate}
16
+ min_threshold_db: -42.0
17
+ max_threshold_db: -20.0
18
+ min_ratio: 1.5
19
+ max_ratio: 6.0
20
+ Reverb:
21
+ _target_: remfx.effects.RandomPedalboardReverb
22
+ sample_rate: ${sample_rate}
23
+ min_room_size: 0.3
24
+ max_room_size: 1.0
25
+ min_damping: 0.2
26
+ max_damping: 1.0
27
+ min_wet_dry: 0.2
28
+ max_wet_dry: 0.8
29
+ min_width: 0.2
30
+ max_width: 1.0
31
+
cfg/{effects β†’ effect_to_remove}/chorus.yaml RENAMED
@@ -1,5 +1,5 @@
1
  # @package _global_
2
- effects:
3
  Chorus:
4
  _target_: remfx.effects.RandomPedalboardChorus
5
  sample_rate: ${sample_rate}
 
1
  # @package _global_
2
+ effect_to_remove:
3
  Chorus:
4
  _target_: remfx.effects.RandomPedalboardChorus
5
  sample_rate: ${sample_rate}
cfg/{effects β†’ effect_to_remove}/compressor.yaml RENAMED
@@ -1,5 +1,5 @@
1
  # @package _global_
2
- effects:
3
  Compressor:
4
  _target_: remfx.effects.RandomPedalboardCompressor
5
  sample_rate: ${sample_rate}
 
1
  # @package _global_
2
+ effect_to_remove:
3
  Compressor:
4
  _target_: remfx.effects.RandomPedalboardCompressor
5
  sample_rate: ${sample_rate}
cfg/{effects β†’ effect_to_remove}/distortion.yaml RENAMED
@@ -1,5 +1,5 @@
1
  # @package _global_
2
- effects:
3
  Distortion:
4
  _target_: remfx.effects.RandomPedalboardDistortion
5
  sample_rate: ${sample_rate}
 
1
  # @package _global_
2
+ effect_to_remove:
3
  Distortion:
4
  _target_: remfx.effects.RandomPedalboardDistortion
5
  sample_rate: ${sample_rate}
cfg/{effects β†’ effect_to_remove}/reverb.yaml RENAMED
@@ -1,5 +1,5 @@
1
  # @package _global_
2
- effects:
3
  Reverb:
4
  _target_: remfx.effects.RandomPedalboardReverb
5
  sample_rate: ${sample_rate}
 
1
  # @package _global_
2
+ effect_to_remove:
3
  Reverb:
4
  _target_: remfx.effects.RandomPedalboardReverb
5
  sample_rate: ${sample_rate}
cfg/exp/demucs_all.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
- - override /effects: all
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: all
cfg/exp/demucs_chorus.yaml CHANGED
@@ -1,4 +1,6 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
- - override /effects: chorus
 
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: chorus
6
+
cfg/exp/demucs_compressor.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
- - override /effects: compressor
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: distortion
cfg/exp/demucs_distortion.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
- - override /effects: distortion
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: distortion
cfg/exp/demucs_reverb.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
- - override /effects: reverb
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: demucs
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: reverb
cfg/exp/umx_all.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
- - override /effects: all
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: all
cfg/exp/umx_chorus.yaml CHANGED
@@ -1,4 +1,6 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
- - override /effects: chorus
 
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: chorus
6
+
cfg/exp/umx_compressor.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
- - override /effects: compressor
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: distortion
cfg/exp/umx_distortion.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
- - override /effects: distortion
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: distortion
cfg/exp/umx_reverb.yaml CHANGED
@@ -1,4 +1,5 @@
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
- - override /effects: reverb
 
 
1
  # @package _global_
2
  defaults:
3
  - override /model: umx
4
+ - override /applied_effects: all
5
+ - override /effect_to_remove: reverb
remfx/datasets.py CHANGED
@@ -1,16 +1,20 @@
1
  import torch
2
  from torch.utils.data import Dataset, DataLoader
3
- import torchaudio
4
  import torch.nn.functional as F
 
5
  from pathlib import Path
6
  import pytorch_lightning as pl
7
- from typing import Any, List
 
8
  from remfx import effects
9
  from tqdm import tqdm
10
  from remfx.utils import create_sequential_chunks
 
11
 
12
  # https://zenodo.org/record/1193957 -> VocalSet
13
 
 
 
14
 
15
  class VocalSet(Dataset):
16
  def __init__(
@@ -18,7 +22,9 @@ class VocalSet(Dataset):
18
  root: str,
19
  sample_rate: int,
20
  chunk_size: int = 3,
21
- effect_types: List[torch.nn.Module] = None,
 
 
22
  render_files: bool = True,
23
  render_root: str = None,
24
  mode: str = "train",
@@ -31,22 +37,36 @@ class VocalSet(Dataset):
31
  self.chunk_size = chunk_size
32
  self.sample_rate = sample_rate
33
  self.mode = mode
34
-
 
35
  mode_path = self.root / self.mode
36
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
37
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
38
- self.effect_types = effect_types
39
- effect_str = "_".join([e for e in self.effect_types])
40
- self.processed_root = self.render_root / "processed" / effect_str / self.mode
41
- if self.processed_root.exists():
 
 
 
 
42
  print("Found processed files.")
43
- render_files = False
 
 
 
 
 
 
 
 
 
44
  self.num_chunks = 0
45
  print("Total files:", len(self.files))
46
  print("Processing files...")
47
  if render_files:
48
  # Split audio file into chunks, resample, then apply random effects
49
- self.processed_root.mkdir(parents=True, exist_ok=True)
50
  for audio_file in tqdm(self.files, total=len(self.files)):
51
  chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
52
  for chunk in chunks:
@@ -56,27 +76,16 @@ class VocalSet(Dataset):
56
  if resampled_chunk.shape[-1] < chunk_size:
57
  # Skip if chunk is too small
58
  continue
59
- # Apply effect
60
- effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
61
- effect_name = list(self.effect_types.keys())[int(effect_idx)]
62
- effect = self.effect_types[effect_name]
63
- effected_input = effect(resampled_chunk)
64
- # Normalize
65
- normalized_input = self.normalize(effected_input)
66
- normalized_target = self.normalize(resampled_chunk)
67
-
68
- output_dir = self.processed_root / str(self.num_chunks)
69
  output_dir.mkdir(exist_ok=True)
70
- torchaudio.save(
71
- output_dir / "input.wav", normalized_input, self.sample_rate
72
- )
73
- torchaudio.save(
74
- output_dir / "target.wav", normalized_target, self.sample_rate
75
- )
76
- torch.save(effect_name, output_dir / "effect_name.pt")
77
  self.num_chunks += 1
78
  else:
79
- self.num_chunks = len(list(self.processed_root.iterdir()))
80
 
81
  print(
82
  f"Found {len(self.files)} {self.mode} files .\n"
@@ -87,13 +96,47 @@ class VocalSet(Dataset):
87
  return self.num_chunks
88
 
89
  def __getitem__(self, idx):
90
- input_file = self.processed_root / str(idx) / "input.wav"
91
- target_file = self.processed_root / str(idx) / "target.wav"
92
- effect_name = torch.load(self.processed_root / str(idx) / "effect_name.pt")
93
  input, sr = torchaudio.load(input_file)
94
  target, sr = torchaudio.load(target_file)
95
  return (input, target, effect_name)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  class VocalSetDatamodule(pl.LightningDataModule):
99
  def __init__(
 
1
  import torch
2
  from torch.utils.data import Dataset, DataLoader
 
3
  import torch.nn.functional as F
4
+ import torchaudio
5
  from pathlib import Path
6
  import pytorch_lightning as pl
7
+ import sys
8
+ from typing import Any, Dict
9
  from remfx import effects
10
  from tqdm import tqdm
11
  from remfx.utils import create_sequential_chunks
12
+ import shutil
13
 
14
  # https://zenodo.org/record/1193957 -> VocalSet
15
 
16
+ ALL_EFFECTS = effects.Pedalboard_Effects
17
+
18
 
19
  class VocalSet(Dataset):
20
  def __init__(
 
22
  root: str,
23
  sample_rate: int,
24
  chunk_size: int = 3,
25
+ applied_effects: Dict[str, torch.nn.Module] = None,
26
+ effect_to_remove: Dict[str, torch.nn.Module] = None,
27
+ max_effects_per_file: int = 1,
28
  render_files: bool = True,
29
  render_root: str = None,
30
  mode: str = "train",
 
37
  self.chunk_size = chunk_size
38
  self.sample_rate = sample_rate
39
  self.mode = mode
40
+ self.max_effects_per_file = max_effects_per_file
41
+ self.effect_to_remove = effect_to_remove
42
  mode_path = self.root / self.mode
43
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
44
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
45
+ self.applied_effects = applied_effects
46
+ self.effect_to_remove_name = list(effect_to_remove.keys())[0]
47
+
48
+ effect_str = "_".join([e for e in self.applied_effects])
49
+ effect_str += f"_{self.effect_to_remove_name}"
50
+ self.proc_root = self.render_root / "processed" / effect_str / self.mode
51
+
52
+ if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
53
  print("Found processed files.")
54
+ if render_files:
55
+ re_render = input(
56
+ "WARNING: By default, will re-render files.\n"
57
+ "Set render_files=False to skip re-rendering.\n"
58
+ "Are you sure you want to re-render? (y/n): "
59
+ )
60
+ if re_render != "y":
61
+ sys.exit()
62
+ shutil.rmtree(self.proc_root)
63
+
64
  self.num_chunks = 0
65
  print("Total files:", len(self.files))
66
  print("Processing files...")
67
  if render_files:
68
  # Split audio file into chunks, resample, then apply random effects
69
+ self.proc_root.mkdir(parents=True, exist_ok=True)
70
  for audio_file in tqdm(self.files, total=len(self.files)):
71
  chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
72
  for chunk in chunks:
 
76
  if resampled_chunk.shape[-1] < chunk_size:
77
  # Skip if chunk is too small
78
  continue
79
+
80
+ x, y, effect = self.process_effects(resampled_chunk)
81
+ output_dir = self.proc_root / str(self.num_chunks)
 
 
 
 
 
 
 
82
  output_dir.mkdir(exist_ok=True)
83
+ torchaudio.save(output_dir / "input.wav", x, self.sample_rate)
84
+ torchaudio.save(output_dir / "target.wav", y, self.sample_rate)
85
+ torch.save(effect, output_dir / "effect.pt")
 
 
 
 
86
  self.num_chunks += 1
87
  else:
88
+ self.num_chunks = len(list(self.proc_root.iterdir()))
89
 
90
  print(
91
  f"Found {len(self.files)} {self.mode} files .\n"
 
96
  return self.num_chunks
97
 
98
  def __getitem__(self, idx):
99
+ input_file = self.proc_root / str(idx) / "input.wav"
100
+ target_file = self.proc_root / str(idx) / "target.wav"
101
+ effect_name = torch.load(self.proc_root / str(idx) / "effect.pt")
102
  input, sr = torchaudio.load(input_file)
103
  target, sr = torchaudio.load(target_file)
104
  return (input, target, effect_name)
105
 
106
+ def process_effects(self, dry: torch.Tensor):
107
+ # Apply random number of effects up to num_effects - 1 (excluding effect_to_remove)
108
+ if self.max_effects_per_file > 1:
109
+ num_effects = torch.randint(self.max_effects_per_file - 1, (1,)).item()
110
+ # Remove effect to remove from applied effects if present
111
+ self.applied_effects.pop(self.effect_to_remove_name, None)
112
+
113
+ # Choose random effects to apply
114
+ effect_indices = torch.randperm(len(self.applied_effects.keys()))[
115
+ :num_effects
116
+ ]
117
+ effects_to_apply = [
118
+ list(self.applied_effects.keys())[i] for i in effect_indices
119
+ ]
120
+ labels = []
121
+ for effect_name in effects_to_apply:
122
+ effect = self.applied_effects[effect_name]
123
+ dry = effect(dry)
124
+ labels.append(ALL_EFFECTS.index(type(effect)))
125
+
126
+ # Apply effect_to_remove
127
+ effect = self.effect_to_remove[self.effect_to_remove_name]
128
+ wet = effect(torch.clone(dry))
129
+ labels.append(ALL_EFFECTS.index(type(effect)))
130
+
131
+ # Convert labels to one-hot
132
+ one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
133
+ effects_present = torch.sum(one_hot, dim=0).float()
134
+
135
+ # Normalize
136
+ normalized_dry = self.normalize(dry)
137
+ normalized_wet = self.normalize(wet)
138
+ return normalized_dry, normalized_wet, effects_present
139
+
140
 
141
  class VocalSetDatamodule(pl.LightningDataModule):
142
  def __init__(
remfx/effects.py CHANGED
@@ -675,7 +675,7 @@ class RandomAudioEffectsChannel(torch.nn.Module):
675
  p=compressor_prob,
676
  ),
677
  RandomApply(
678
- [RandomPebalboardReverb(sample_rate)],
679
  p=reverb_prob,
680
  ),
681
  RandomApply(
@@ -696,3 +696,14 @@ class RandomAudioEffectsChannel(torch.nn.Module):
696
 
697
  def forward(self, x: torch.Tensor):
698
  return self.transforms(x)
 
 
 
 
 
 
 
 
 
 
 
 
675
  p=compressor_prob,
676
  ),
677
  RandomApply(
678
+ [RandomPedalboardReverb(sample_rate)],
679
  p=reverb_prob,
680
  ),
681
  RandomApply(
 
696
 
697
  def forward(self, x: torch.Tensor):
698
  return self.transforms(x)
699
+
700
+
701
+ Pedalboard_Effects = [
702
+ RandomPedalboardReverb,
703
+ RandomPedalboardChorus,
704
+ RandomPedalboardDelay,
705
+ RandomPedalboardDistortion,
706
+ RandomPedalboardCompressor,
707
+ RandomPedalboardPhaser,
708
+ RandomPedalboardLimiter,
709
+ ]
remfx/models.py CHANGED
@@ -64,7 +64,6 @@ class RemFXModel(pl.LightningModule):
64
  optimizer_idx,
65
  optimizer_closure,
66
  on_tpu,
67
- using_native_amp,
68
  using_lbfgs,
69
  ):
70
  # update params
 
64
  optimizer_idx,
65
  optimizer_closure,
66
  on_tpu,
 
67
  using_lbfgs,
68
  ):
69
  # update params
setup.py CHANGED
@@ -30,7 +30,7 @@ setup(
30
  packages=find_packages(),
31
  install_requires=[
32
  "torch>=1.11.0",
33
- "torchaudio",
34
  "functorch",
35
  "scipy",
36
  "numpy",
 
30
  packages=find_packages(),
31
  install_requires=[
32
  "torch>=1.11.0",
33
+ "torchaudio>=0.13.0",
34
  "functorch",
35
  "scipy",
36
  "numpy",