mattricesound commited on
Commit
8cb3861
1 Parent(s): fe64756

Refactor commit, see README for details

Browse files
.gitignore CHANGED
@@ -8,4 +8,5 @@ __pycache__/
8
  lightning_logs/
9
  outputs/
10
  logs/
11
- .vscode/
 
 
8
  lightning_logs/
9
  outputs/
10
  logs/
11
+ .vscode/
12
+ ckpts/
README.md CHANGED
@@ -14,32 +14,35 @@
14
 
15
  ## Train model
16
  1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
17
- 2. `python scripts/train.py +exp=umx_distortion`
18
- or
19
- 2. `python scripts/train.py +exp=demucs_distortion`
20
- See cfg for more options. Generally they are `+exp={model}_{effect}`
21
- Models and effects detailed below.
22
 
23
- To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
24
 
25
- Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=1`
26
-
27
- ### Current Models
28
  - `umx`
29
  - `demucs`
30
 
31
- ### Current Effects
32
  - `chorus`
33
  - `compressor`
34
  - `distortion`
35
  - `reverb`
36
- - `all` (choose random effect to apply to each file)
37
 
38
- ### Testing
39
- Experiment dictates data, ckpt dictates model
40
- `python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  ## Misc.
43
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
44
- To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
45
- To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
 
14
 
15
  ## Train model
16
  1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
17
+ 2. `python scripts/train.py model=demucs "effects_to_remove=[distortion]"`
 
 
 
 
18
 
 
19
 
20
+ ## Models
 
 
21
  - `umx`
22
  - `demucs`
23
 
24
+ ## Effects
25
  - `chorus`
26
  - `compressor`
27
  - `distortion`
28
  - `reverb`
 
29
 
30
+ ## Train CLI Options
31
+ - `max_kept_effects={n}` max number of <b> Kept </b> effects to apply to each file (default: 3)
32
+ - `model={model}` architecture to use (see 'Models')
33
+ - `shuffle_kept_effects=True/False` Shuffle kept effects (default: True)
34
+ - `shuffle_removed_effects=True/False` Shuffle Removed effects (default: False)
35
+ - `effects_to_use={effect}` Effects to use (see 'Effects') (default: all in the list)
36
+ - `effects_to_remove={effect}` Effects to remove (see 'Effects') (default: all in the list)
37
+ - `trainer.accelerator='gpu'` : Use GPU (default: None)
38
+ - `trainer.devices={n}` Number of GPUs to use (default: 1)
39
+ - `render_files=True/False` Render files. Disable to skip rendering stage (default: True)
40
+ - `render_root={path/to/dir}`. Root directory to render files to (default: DATASET_ROOT)
41
+
42
+ Example: `python scripts/train.py model=demucs "effects_to_use=[distortion, reverb]" "effects_to_remove=[distortion]" "max_kept_effects=2" "shuffle_kept_effects=False" "shuffle_removed_effects=True" trainer.accelerator='gpu' trainer.devices=2`
43
+
44
 
45
  ## Misc.
46
  By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
47
+
48
+
cfg/applied_effects/all.yaml DELETED
@@ -1,31 +0,0 @@
1
- # @package _global_
2
- applied_effects:
3
- Chorus:
4
- _target_: remfx.effects.RandomPedalboardChorus
5
- sample_rate: ${sample_rate}
6
- min_depth: 0.2
7
- min_mix: 0.3
8
- Distortion:
9
- _target_: remfx.effects.RandomPedalboardDistortion
10
- sample_rate: ${sample_rate}
11
- min_drive_db: 10
12
- max_drive_db: 50
13
- Compressor:
14
- _target_: remfx.effects.RandomPedalboardCompressor
15
- sample_rate: ${sample_rate}
16
- min_threshold_db: -42.0
17
- max_threshold_db: -20.0
18
- min_ratio: 1.5
19
- max_ratio: 6.0
20
- Reverb:
21
- _target_: remfx.effects.RandomPedalboardReverb
22
- sample_rate: ${sample_rate}
23
- min_room_size: 0.3
24
- max_room_size: 1.0
25
- min_damping: 0.2
26
- max_damping: 1.0
27
- min_wet_dry: 0.2
28
- max_wet_dry: 0.8
29
- min_width: 0.2
30
- max_width: 1.0
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/config.yaml CHANGED
@@ -1,10 +1,9 @@
1
  defaults:
2
  - _self_
3
  - model: null
4
- - applied_effects: null
5
- - effect_to_remove: null
6
 
7
- max_effects_per_file: 3
8
  seed: 12345
9
  train: True
10
  sample_rate: 48000
@@ -13,6 +12,20 @@ logs_dir: "./logs"
13
  render_files: True
14
  render_root: "./data"
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  callbacks:
17
  model_checkpoint:
18
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
@@ -35,9 +48,12 @@ datamodule:
35
  root: ${oc.env:DATASET_ROOT}
36
  chunk_size: ${chunk_size}
37
  mode: "train"
38
- applied_effects: ${applied_effects}
39
- effect_to_remove: ${effect_to_remove}
40
- max_effects_per_file: ${max_effects_per_file}
 
 
 
41
  render_files: ${render_files}
42
  render_root: ${render_root}
43
  val_dataset:
@@ -46,9 +62,12 @@ datamodule:
46
  root: ${oc.env:DATASET_ROOT}
47
  chunk_size: ${chunk_size}
48
  mode: "val"
49
- applied_effects: ${applied_effects}
50
- effect_to_remove: ${effect_to_remove}
51
- max_effects_per_file: ${max_effects_per_file}
 
 
 
52
  render_files: ${render_files}
53
  render_root: ${render_root}
54
  test_dataset:
@@ -57,9 +76,12 @@ datamodule:
57
  root: ${oc.env:DATASET_ROOT}
58
  chunk_size: ${chunk_size}
59
  mode: "test"
60
- applied_effects: ${applied_effects}
61
- effect_to_remove: ${effect_to_remove}
62
- max_effects_per_file: ${max_effects_per_file}
 
 
 
63
  render_files: ${render_files}
64
  render_root: ${render_root}
65
 
@@ -89,3 +111,4 @@ trainer:
89
  devices: 1
90
  gradient_clip_val: 10.0
91
  max_steps: 50000
 
 
1
  defaults:
2
  - _self_
3
  - model: null
4
+ - effects: all
5
+
6
 
 
7
  seed: 12345
8
  train: True
9
  sample_rate: 48000
 
12
  render_files: True
13
  render_root: "./data"
14
 
15
+ max_kept_effects: 3
16
+ shuffle_kept_effects: True
17
+ shuffle_removed_effects: False
18
+ effects_to_use:
19
+ - compressor
20
+ - distortion
21
+ - reverb
22
+ - chorus
23
+ effects_to_remove:
24
+ - compressor
25
+ - distortion
26
+ - reverb
27
+ - chorus
28
+
29
  callbacks:
30
  model_checkpoint:
31
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
 
48
  root: ${oc.env:DATASET_ROOT}
49
  chunk_size: ${chunk_size}
50
  mode: "train"
51
+ effect_modules: ${effects}
52
+ effects_to_use: ${effects_to_use}
53
+ effects_to_remove: ${effects_to_remove}
54
+ max_kept_effects: ${max_kept_effects}
55
+ shuffle_kept_effects: ${shuffle_kept_effects}
56
+ shuffle_removed_effects: ${shuffle_removed_effects}
57
  render_files: ${render_files}
58
  render_root: ${render_root}
59
  val_dataset:
 
62
  root: ${oc.env:DATASET_ROOT}
63
  chunk_size: ${chunk_size}
64
  mode: "val"
65
+ effect_modules: ${effects}
66
+ effects_to_use: ${effects_to_use}
67
+ effects_to_remove: ${effects_to_remove}
68
+ max_kept_effects: ${max_kept_effects}
69
+ shuffle_kept_effects: ${shuffle_kept_effects}
70
+ shuffle_removed_effects: ${shuffle_removed_effects}
71
  render_files: ${render_files}
72
  render_root: ${render_root}
73
  test_dataset:
 
76
  root: ${oc.env:DATASET_ROOT}
77
  chunk_size: ${chunk_size}
78
  mode: "test"
79
+ effect_modules: ${effects}
80
+ effects_to_use: ${effects_to_use}
81
+ effects_to_remove: ${effects_to_remove}
82
+ max_kept_effects: ${max_kept_effects}
83
+ shuffle_kept_effects: ${shuffle_kept_effects}
84
+ shuffle_removed_effects: ${shuffle_removed_effects}
85
  render_files: ${render_files}
86
  render_root: ${render_root}
87
 
 
111
  devices: 1
112
  gradient_clip_val: 10.0
113
  max_steps: 50000
114
+
cfg/effect_to_remove/all.yaml DELETED
@@ -1,31 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Chorus:
4
- _target_: remfx.effects.RandomPedalboardChorus
5
- sample_rate: ${sample_rate}
6
- min_depth: 0.2
7
- min_mix: 0.3
8
- Distortion:
9
- _target_: remfx.effects.RandomPedalboardDistortion
10
- sample_rate: ${sample_rate}
11
- min_drive_db: 10
12
- max_drive_db: 50
13
- Compressor:
14
- _target_: remfx.effects.RandomPedalboardCompressor
15
- sample_rate: ${sample_rate}
16
- min_threshold_db: -42.0
17
- max_threshold_db: -20.0
18
- min_ratio: 1.5
19
- max_ratio: 6.0
20
- Reverb:
21
- _target_: remfx.effects.RandomPedalboardReverb
22
- sample_rate: ${sample_rate}
23
- min_room_size: 0.3
24
- max_room_size: 1.0
25
- min_damping: 0.2
26
- max_damping: 1.0
27
- min_wet_dry: 0.2
28
- max_wet_dry: 0.8
29
- min_width: 0.2
30
- max_width: 1.0
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/effect_to_remove/chorus.yaml DELETED
@@ -1,7 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Chorus:
4
- _target_: remfx.effects.RandomPedalboardChorus
5
- sample_rate: ${sample_rate}
6
- min_depth: 0.2
7
- min_mix: 0.3
 
 
 
 
 
 
 
 
cfg/effect_to_remove/compressor.yaml DELETED
@@ -1,9 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Compressor:
4
- _target_: remfx.effects.RandomPedalboardCompressor
5
- sample_rate: ${sample_rate}
6
- min_threshold_db: -42.0
7
- max_threshold_db: -20.0
8
- min_ratio: 1.5
9
- max_ratio: 6.0
 
 
 
 
 
 
 
 
 
 
cfg/effect_to_remove/distortion.yaml DELETED
@@ -1,7 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Distortion:
4
- _target_: remfx.effects.RandomPedalboardDistortion
5
- sample_rate: ${sample_rate}
6
- min_drive_db: 10
7
- max_drive_db: 50
 
 
 
 
 
 
 
 
cfg/effect_to_remove/reverb.yaml DELETED
@@ -1,13 +0,0 @@
1
- # @package _global_
2
- effect_to_remove:
3
- Reverb:
4
- _target_: remfx.effects.RandomPedalboardReverb
5
- sample_rate: ${sample_rate}
6
- min_room_size: 0.3
7
- max_room_size: 1.0
8
- min_damping: 0.2
9
- max_damping: 1.0
10
- min_wet_dry: 0.2
11
- max_wet_dry: 0.8
12
- min_width: 0.2
13
- max_width: 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfg/effects/all.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ effects:
4
+ chorus:
5
+ _target_: remfx.effects.RandomPedalboardChorus
6
+ sample_rate: ${sample_rate}
7
+ min_depth: 0.2
8
+ min_mix: 0.3
9
+ distortion:
10
+ _target_: remfx.effects.RandomPedalboardDistortion
11
+ sample_rate: ${sample_rate}
12
+ min_drive_db: 10
13
+ max_drive_db: 50
14
+ compressor:
15
+ _target_: remfx.effects.RandomPedalboardCompressor
16
+ sample_rate: ${sample_rate}
17
+ min_threshold_db: -42.0
18
+ max_threshold_db: -20.0
19
+ min_ratio: 1.5
20
+ max_ratio: 6.0
21
+ reverb:
22
+ _target_: remfx.effects.RandomPedalboardReverb
23
+ sample_rate: ${sample_rate}
24
+ min_room_size: 0.3
25
+ max_room_size: 1.0
26
+ min_damping: 0.2
27
+ max_damping: 1.0
28
+ min_wet_dry: 0.2
29
+ max_wet_dry: 0.8
30
+ min_width: 0.2
31
+ max_width: 1.0
cfg/exp/demucs_all.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: all
 
 
 
 
 
 
cfg/exp/demucs_chorus.yaml DELETED
@@ -1,6 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: chorus
6
-
 
 
 
 
 
 
 
cfg/exp/demucs_compressor.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/demucs_distortion.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/demucs_reverb.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: demucs
4
- - override /applied_effects: all
5
- - override /effect_to_remove: reverb
 
 
 
 
 
 
cfg/exp/umx_all.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: all
 
 
 
 
 
 
cfg/exp/umx_chorus.yaml DELETED
@@ -1,6 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: chorus
6
-
 
 
 
 
 
 
 
cfg/exp/umx_compressor.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/umx_distortion.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: distortion
 
 
 
 
 
 
cfg/exp/umx_reverb.yaml DELETED
@@ -1,5 +0,0 @@
1
- # @package _global_
2
- defaults:
3
- - override /model: umx
4
- - override /applied_effects: all
5
- - override /effect_to_remove: reverb
 
 
 
 
 
 
remfx/datasets.py CHANGED
@@ -5,12 +5,13 @@ import torchaudio
5
  from pathlib import Path
6
  import pytorch_lightning as pl
7
  import sys
8
- from typing import Any, Dict
9
  from remfx import effects
10
  from tqdm import tqdm
11
  from remfx.utils import create_sequential_chunks
12
  import shutil
13
 
 
14
  # https://zenodo.org/record/1193957 -> VocalSet
15
 
16
  ALL_EFFECTS = effects.Pedalboard_Effects
@@ -22,9 +23,12 @@ class VocalSet(Dataset):
22
  root: str,
23
  sample_rate: int,
24
  chunk_size: int = 3,
25
- applied_effects: Dict[str, torch.nn.Module] = None,
26
- effect_to_remove: Dict[str, torch.nn.Module] = None,
27
- max_effects_per_file: int = 1,
 
 
 
28
  render_files: bool = True,
29
  render_root: str = None,
30
  mode: str = "train",
@@ -37,17 +41,19 @@ class VocalSet(Dataset):
37
  self.chunk_size = chunk_size
38
  self.sample_rate = sample_rate
39
  self.mode = mode
40
- self.max_effects_per_file = max_effects_per_file
41
- self.effect_to_remove = effect_to_remove
42
  mode_path = self.root / self.mode
43
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
 
 
 
44
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
45
- self.applied_effects = applied_effects
46
- self.effect_to_remove_name = "_".join([e for e in self.effect_to_remove])
 
47
 
48
- effect_str = "__".join([e for e in self.applied_effects])
49
- effect_str += f"_{self.effect_to_remove_name}"
50
- self.proc_root = self.render_root / "processed" / effect_str / self.mode
51
 
52
  if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
53
  print("Found processed files.")
@@ -103,38 +109,66 @@ class VocalSet(Dataset):
103
  target, sr = torchaudio.load(target_file)
104
  return (input, target, effect_name)
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def process_effects(self, dry: torch.Tensor):
107
- # Apply random number of effects up to num_effects - 1 (excluding effect_to_remove)
108
- if self.max_effects_per_file > 1:
109
- num_effects = torch.randint(self.max_effects_per_file - 1, (1,)).item()
110
- # Remove effect to remove from applied effects if present
111
- for effect in self.effect_to_remove:
112
- self.applied_effects.pop(effect, None)
113
-
114
- # Choose random effects to apply
115
- effect_indices = torch.randperm(len(self.applied_effects.keys()))[
116
- :num_effects
117
- ]
118
- effects_to_apply = [
119
- list(self.applied_effects.keys())[i] for i in effect_indices
120
- ]
121
- labels = []
122
- for effect_name in effects_to_apply:
123
- effect = self.applied_effects[effect_name]
124
- dry = effect(dry)
125
- labels.append(ALL_EFFECTS.index(type(effect)))
126
-
127
- # Apply effect_to_remove
128
  wet = torch.clone(dry)
129
- for effect_name in self.effect_to_remove:
130
- effect = self.effect_to_remove[effect_name]
131
- wet = effect(dry)
 
 
 
 
 
132
  labels.append(ALL_EFFECTS.index(type(effect)))
133
 
134
  # Convert labels to one-hot
135
  one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
136
  effects_present = torch.sum(one_hot, dim=0).float()
137
-
138
  # Normalize
139
  normalized_dry = self.normalize(dry)
140
  normalized_wet = self.normalize(wet)
 
5
  from pathlib import Path
6
  import pytorch_lightning as pl
7
  import sys
8
+ from typing import Any, List, Dict
9
  from remfx import effects
10
  from tqdm import tqdm
11
  from remfx.utils import create_sequential_chunks
12
  import shutil
13
 
14
+
15
  # https://zenodo.org/record/1193957 -> VocalSet
16
 
17
  ALL_EFFECTS = effects.Pedalboard_Effects
 
23
  root: str,
24
  sample_rate: int,
25
  chunk_size: int = 3,
26
+ effect_modules: List[Dict[str, torch.nn.Module]] = None,
27
+ effects_to_use: List[str] = None,
28
+ effects_to_remove: List[str] = None,
29
+ max_kept_effects: int = 1,
30
+ shuffle_kept_effects: bool = True,
31
+ shuffle_removed_effects: bool = False,
32
  render_files: bool = True,
33
  render_root: str = None,
34
  mode: str = "train",
 
41
  self.chunk_size = chunk_size
42
  self.sample_rate = sample_rate
43
  self.mode = mode
 
 
44
  mode_path = self.root / self.mode
45
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
46
+ self.max_kept_effects = max_kept_effects
47
+ self.effects_to_use = effects_to_use
48
+ self.effects_to_remove = effects_to_remove
49
  self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
50
+ self.effects = effect_modules
51
+ self.shuffle_kept_effects = shuffle_kept_effects
52
+ self.shuffle_removed_effects = shuffle_removed_effects
53
 
54
+ effects_string = "_".join(self.effects_to_use + ["_"] + self.effects_to_remove)
55
+ self.effects_to_keep = self.validate_effect_input()
56
+ self.proc_root = self.render_root / "processed" / effects_string / self.mode
57
 
58
  if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
59
  print("Found processed files.")
 
109
  target, sr = torchaudio.load(target_file)
110
  return (input, target, effect_name)
111
 
112
+ def validate_effect_input(self):
113
+ for effect in self.effects.values():
114
+ if type(effect) not in ALL_EFFECTS:
115
+ raise ValueError(
116
+ f"Effect {effect} not found in ALL_EFFECTS. "
117
+ f"Please choose from {ALL_EFFECTS}"
118
+ )
119
+ for effect in self.effects_to_use:
120
+ if effect not in self.effects.keys():
121
+ raise ValueError(
122
+ f"Effect {effect} not found in self.effects. "
123
+ f"Please choose from {self.effects.keys()}"
124
+ )
125
+ for effect in self.effects_to_remove:
126
+ if effect not in self.effects.keys():
127
+ raise ValueError(
128
+ f"Effect {effect} not found in self.effects. "
129
+ f"Please choose from {self.effects.keys()}"
130
+ )
131
+ kept_fx = list(set(self.effects_to_use) - set(self.effects_to_remove))
132
+ kept_str = "randomly" if self.shuffle_kept_effects else "in order"
133
+ removed_str = "randomly" if self.shuffle_removed_effects else "in order"
134
+ rem_fx = self.effects_to_remove
135
+ print(
136
+ f"Effect Summary: \n"
137
+ f"Apply effects: {kept_fx} (Up to {self.max_kept_effects}, chosen {kept_str}) -> Dry\n"
138
+ f"Apply effects: {rem_fx} (All {len(rem_fx)}, chosen {removed_str}) -> Wet\n"
139
+ )
140
+ return kept_fx
141
+
142
  def process_effects(self, dry: torch.Tensor):
143
+ labels = []
144
+
145
+ # Apply Kept Effects
146
+ if self.shuffle_kept_effects:
147
+ effect_indices = torch.randperm(len(self.effects_to_keep))
148
+ else:
149
+ effect_indices = torch.arange(len(self.effects_to_keep))
150
+ effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
151
+ effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
152
+ for effect in effects_to_apply:
153
+ dry = effect(dry)
154
+ labels.append(ALL_EFFECTS.index(type(effect)))
155
+ print(labels)
156
+
157
+ # Apply effects_to_remove
 
 
 
 
 
 
158
  wet = torch.clone(dry)
159
+ if self.shuffle_removed_effects:
160
+ effect_indices = torch.randperm(len(self.effects_to_remove))
161
+ else:
162
+ effect_indices = torch.arange(len(self.effects_to_remove))
163
+ effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
164
+ effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
165
+ for effect in effects_to_apply:
166
+ wet = effect(wet)
167
  labels.append(ALL_EFFECTS.index(type(effect)))
168
 
169
  # Convert labels to one-hot
170
  one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
171
  effects_present = torch.sum(one_hot, dim=0).float()
 
172
  # Normalize
173
  normalized_dry = self.normalize(dry)
174
  normalized_wet = self.normalize(wet)