mattricesound commited on
Commit
7bb4fe3
β€’
1 Parent(s): e0a5f6f

Refactored config

Browse files
config.yaml β†’ cfg/config.yaml RENAMED
@@ -1,6 +1,8 @@
1
  defaults:
2
  - _self_
3
- - exp: null
 
 
4
  seed: 12345
5
  train: True
6
  sample_rate: 48000
@@ -26,12 +28,14 @@ datamodule:
26
  root: ${oc.env:DATASET_ROOT}
27
  chunk_size_in_sec: 6
28
  mode: "train"
 
29
  val_dataset:
30
  _target_: remfx.datasets.VocalSet
31
  sample_rate: ${sample_rate}
32
  root: ${oc.env:DATASET_ROOT}
33
  chunk_size_in_sec: 6
34
  mode: "val"
 
35
  batch_size: 16
36
  num_workers: 8
37
  pin_memory: True
 
1
  defaults:
2
  - _self_
3
+ - model: null
4
+ - effects: null
5
+
6
  seed: 12345
7
  train: True
8
  sample_rate: 48000
 
28
  root: ${oc.env:DATASET_ROOT}
29
  chunk_size_in_sec: 6
30
  mode: "train"
31
+ effect_types: ${effects.train_effects}
32
  val_dataset:
33
  _target_: remfx.datasets.VocalSet
34
  sample_rate: ${sample_rate}
35
  root: ${oc.env:DATASET_ROOT}
36
  chunk_size_in_sec: 6
37
  mode: "val"
38
+ effect_types: ${effects.val_effects}
39
  batch_size: 16
40
  num_workers: 8
41
  pin_memory: True
config_guitarset.yaml β†’ cfg/config_guitarset.yaml RENAMED
File without changes
config_guitfx.yaml β†’ cfg/config_guitfx.yaml RENAMED
File without changes
cfg/effects/distortion.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ effects:
3
+ train_effects:
4
+ Distortion:
5
+ _target_: remfx.effects.RandomPedalboardDistortion
6
+ sample_rate: ${sample_rate}
7
+ min_drive_db: -10
8
+ max_drive_db: 50
9
+ val_effects:
10
+ Distortion:
11
+ _target_: remfx.effects.RandomPedalboardDistortion
12
+ sample_rate: ${sample_rate}
13
+ min_drive_db: 25
14
+ max_drive_db: 25
cfg/exp/demucs_distortion.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: distortion
cfg/exp/umx_distortion.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: umx
4
+ - override /effects: distortion
{exp β†’ cfg/model}/audio_diffusion.yaml RENAMED
File without changes
{exp β†’ cfg/model}/demucs.yaml RENAMED
@@ -13,11 +13,4 @@ model:
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
16
- datamodule:
17
- dataset:
18
- effect_types:
19
- Distortion:
20
- _target_: remfx.effects.RandomPedalboardDistortion
21
- sample_rate: ${sample_rate}
22
- min_drive_db: -10
23
- max_drive_db: 50
 
13
  audio_channels: 1
14
  nfft: 4096
15
  sample_rate: ${sample_rate}
16
+
 
 
 
 
 
 
 
{exp β†’ cfg/model}/umx.yaml RENAMED
@@ -14,18 +14,4 @@ model:
14
  n_channels: 1
15
  alpha: 0.3
16
  sample_rate: ${sample_rate}
17
- datamodule:
18
- train_dataset:
19
- effect_types:
20
- Distortion:
21
- _target_: remfx.effects.RandomPedalboardDistortion
22
- sample_rate: ${sample_rate}
23
- min_drive_db: -10
24
- max_drive_db: 50
25
- val_dataset:
26
- effect_types:
27
- Distortion:
28
- _target_: remfx.effects.RandomPedalboardDistortion
29
- sample_rate: ${sample_rate}
30
- min_drive_db: -10
31
- max_drive_db: 50
 
14
  n_channels: 1
15
  alpha: 0.3
16
  sample_rate: ${sample_rate}
17
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
remfx/datasets.py CHANGED
@@ -190,6 +190,9 @@ class VocalSet(Dataset):
190
  self.chunk_size_in_sec = chunk_size_in_sec
191
  self.sample_rate = sample_rate
192
  self.mode = mode
 
 
 
193
 
194
  mode_path = self.root / self.mode
195
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
@@ -225,18 +228,14 @@ class VocalSet(Dataset):
225
 
226
  # Add random effect if train
227
  if self.mode == "train":
228
- random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
229
- effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
230
- effect = self.effect_types[effect_name]
231
- effected_input = effect(resampled_x)
232
  else:
233
- # deterministic static effect for eval
234
  effect_idx = idx % len(self.effect_types.keys())
235
- effect_name = list(self.effect_types.keys())[effect_idx]
236
- effect = deterministic_effects[effect_name]
237
- effected_input = torch.from_numpy(
238
- effect(resampled_x.numpy(), self.sample_rate)
239
- )
240
  normalized_input = self.normalize(effected_input)
241
  normalized_target = self.normalize(resampled_x)
242
  return (normalized_input, normalized_target, effect_name)
 
190
  self.chunk_size_in_sec = chunk_size_in_sec
191
  self.sample_rate = sample_rate
192
  self.mode = mode
193
+ import pdb
194
+
195
+ pdb.set_trace()
196
 
197
  mode_path = self.root / self.mode
198
  self.files = sorted(list(mode_path.glob("./**/*.wav")))
 
228
 
229
  # Add random effect if train
230
  if self.mode == "train":
231
+ effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
 
 
 
232
  else:
233
+ # deterministic effect for eval
234
  effect_idx = idx % len(self.effect_types.keys())
235
+ effect_name = list(self.effect_types.keys())[int(effect_idx)]
236
+ effect = self.effect_types[effect_name]
237
+ effected_input = effect(resampled_x)
238
+
 
239
  normalized_input = self.normalize(effected_input)
240
  normalized_target = self.normalize(resampled_x)
241
  return (normalized_input, normalized_target, effect_name)
scripts/train.py CHANGED
@@ -7,12 +7,12 @@ from pytorch_lightning.utilities.model_summary import ModelSummary
7
  log = utils.get_logger(__name__)
8
 
9
 
10
- @hydra.main(version_base=None, config_path="../", config_name="config.yaml")
11
  def main(cfg: DictConfig):
12
  # Apply seed for reproducibility
13
  if cfg.seed:
14
  pl.seed_everything(cfg.seed)
15
-
16
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
17
  datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
18
  log.info(f"Instantiating model <{cfg.model._target_}>.")
 
7
  log = utils.get_logger(__name__)
8
 
9
 
10
+ @hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
11
  def main(cfg: DictConfig):
12
  # Apply seed for reproducibility
13
  if cfg.seed:
14
  pl.seed_everything(cfg.seed)
15
+ print(cfg)
16
  log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
17
  datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
18
  log.info(f"Instantiating model <{cfg.model._target_}>.")