mattricesound committed on
Commit
c1b80c0
1 Parent(s): af0842b

Update CSV logger

README.md CHANGED
@@ -1,7 +1,7 @@
1
  # General Purpose Audio Effect Removal
2
  Removing multiple audio effects from multiple sources using a compositional approach built on source separation and speech enhancement models.
3
 
4
- This repo contains the code for the paper [General Purpose Audio Effect Removal](https://arxiv.org/abs/2110.00484). (Todo: Link broken, Add video, Add img)
5
 
6
 
7
 
@@ -9,7 +9,7 @@ This repo contains the code for the paper [General Purpose Audio Effect Removal]
9
  ```
10
  git clone https://github.com/mhrice/RemFx.git
11
  git submodule update --init --recursive
12
- pip install . umx
13
  ```
14
  # Usage
15
  This repo can be used for many different tasks. Here are some examples.
@@ -24,11 +24,11 @@ wget https://zenodo.org/record/8183649/files/RemFX_eval_dataset.zip?download=1 -
24
  unzip RemFX_eval_dataset.zip
25
  ```
26
 
27
- ## Download the datasets used in the paper
28
  ```
29
  python scripts/download.py vocalset guitarset idmt-smt-bass idmt-smt-drums
30
  ```
31
- By default, the datasets are downloaded to `./data/remfx-data`. To change this, pass `--output_dir={path/to/datasets}` to `download.py`
32
 
33
  Then set the dataset root:
34
  ```
@@ -36,7 +36,7 @@ export DATASET_ROOT={path/to/datasets}
36
  ```
37
 
38
  ## Training
39
- Before training, it is important that you have downloaded the datasets (see above) and set DATASET_ROOT.
40
  This project uses the [pytorch-lightning](https://www.pytorchlightning.ai/index.html) framework and [hydra](https://hydra.cc/) for configuration management. All experiments are defined in `cfg/exp/`. To train with an existing experiment, run
41
  ```
42
  python scripts/train.py +exp={experiment_name}
@@ -55,13 +55,17 @@ Here are some selected experiment types from the paper, which use different data
55
  To change the configuration, simply edit the experiment file, or override the configuration on the command line. A description of some of these variables is in the Misc. section below.
56
  You can also create a custom experiment by creating a new experiment file in `cfg/exp/` and overriding the default parameters in `config.yaml`.
57
 
58
- At the end of training, the train script will automatically evaluate the test set using the best checkpoint (by validation loss). To evaluate a specific checkpoint, run
59
 
60
  ```
61
  python test.py +exp={experiment_name} ckpt_path={path/to/checkpoint}
62
  ```
63
 
64
- If you have generated the dataset separately from training, be sure to set `render_files=False` in the config or command-line, and set `render_root={path_to_dataset}` if it is in a custom location.
65
 
66
  Also note that the training assumes you have a GPU. To train on CPU, set `accelerator=null` in the config or command-line.
67
 
@@ -86,16 +90,21 @@ Download checkpoints from [here](https://zenodo.org/record/8179396), or see the
86
 
87
 
88
  ## Generate datasets used in the paper
89
- Before generating datasets, it is important that you have downloaded the datasets (see above) and set DATASET_ROOT.
 
 
90
 
91
- To generate one of the datasets used in the paper, it is as simple as running a training job with a particular config. For example, to generate the `chorus` FXAug dataset, which includes files with 5 possible effects, up to 4 kept effects (distortion, reverb, compression, delay), and 1 removed effects (chorus), run
 
92
  ```
93
- python scripts/train.py +exp=chorus_aug
94
  ```
95
 
96
  See the Misc. section below for a description of the parameters.
97
  By default, files are rendered to `{render_root} / processed / {string_of_effects} / {train|val|test}`.
98
 
 
 
99
  ## Evaluate with a custom directory
100
  Assumes directory is structured as
101
  - root
@@ -120,15 +129,16 @@ python scripts/chain_inference.py +exp=chain_inference_custom
120
 
121
  # Misc.
122
  ## Experimental parameters
123
- Some relevant training parameters descriptions
124
  - `num_kept_effects={[min, max]}` range of <b> Kept </b> effects to apply to each file. Inclusive.
125
  - `num_removed_effects={[min, max]}` range of <b> Removed </b> effects to apply to each file. Inclusive.
126
  - `model={model}` architecture to use (see 'Effect Removal Models/Effect Classification Models')
127
- - `effects_to_keep={[effect]}` Effects to apply but not remove (see 'Effects')
128
  - `effects_to_remove={[effect]}` Effects to remove (see 'Effects')
129
  - `accelerator=null/'gpu'` Use GPU (1 device) (default: null)
130
  - `render_files=True/False` Render files. Disable to skip rendering stage (default: True)
131
  - `render_root={path/to/dir}`. Root directory to render files to (default: ./data)
 
132
 
133
  ### Effect Removal Models
134
  - `umx`
 
1
  # General Purpose Audio Effect Removal
2
  Removing multiple audio effects from multiple sources using a compositional approach built on source separation and speech enhancement models.
3
 
4
+ This repo contains the code for the paper [General Purpose Audio Effect Removal](https://arxiv.org/abs/2110.00484). (Todo: Link broken, Add video, Add img, citation)
5
 
6
 
7
 
 
9
  ```
10
  git clone https://github.com/mhrice/RemFx.git
11
  git submodule update --init --recursive
12
+ pip install -e . ./umx
13
  ```
14
  # Usage
15
  This repo can be used for many different tasks. Here are some examples.
 
24
  unzip RemFX_eval_dataset.zip
25
  ```
26
 
27
+ ## Download the starter datasets
28
  ```
29
  python scripts/download.py vocalset guitarset idmt-smt-bass idmt-smt-drums
30
  ```
31
+ By default, the starter datasets are downloaded to `./data/remfx-data`. To change this, pass `--output_dir={path/to/datasets}` to `download.py`.
32
 
33
  Then set the dataset root:
34
  ```
 
36
  ```
37
 
38
  ## Training
39
+ Before training, it is important that you have downloaded the starter datasets (see above) and set DATASET_ROOT.
40
  This project uses the [pytorch-lightning](https://www.pytorchlightning.ai/index.html) framework and [hydra](https://hydra.cc/) for configuration management. All experiments are defined in `cfg/exp/`. To train with an existing experiment, run
41
  ```
42
  python scripts/train.py +exp={experiment_name}
 
55
  To change the configuration, simply edit the experiment file, or override the configuration on the command line. A description of some of these variables is in the Misc. section below.
56
  You can also create a custom experiment by creating a new experiment file in `cfg/exp/` and overriding the default parameters in `config.yaml`.
57
 
58
+ At the end of training, the train script will automatically evaluate the test set using the best checkpoint (by validation loss). Note that this step will throw an error if training has not completed at least one epoch. To evaluate a specific checkpoint, run
59
 
60
  ```
61
  python test.py +exp={experiment_name} ckpt_path={path/to/checkpoint}
62
  ```
63
 
64
+ The checkpoints will be saved in `./logs/ckpts/{timestamp}`.
65
+ Metrics and hyperparameters will be logged in `./lightning_logs/{timestamp}`.
66
+
67
+ By default, the dataset needed for the experiment is generated before training.
68
+ If you have generated the dataset separately (see 'Generate datasets used in the paper' below), be sure to set `render_files=False` in the config or on the command line, and set `render_root={path_to_dataset}` if it is in a custom location.
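+ Since checkpoints land in timestamped folders, here is a purely illustrative helper for grabbing the newest one to pass as `ckpt_path` (the glob pattern is an assumption, not repo code):
```
# Hypothetical helper: pick the most recent checkpoint under ./logs/ckpts/.
from pathlib import Path

ckpts = sorted(Path("logs/ckpts").glob("**/*.ckpt"), key=lambda p: p.stat().st_mtime)
print(ckpts[-1] if ckpts else "no checkpoints found yet")
```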
69
 
70
  Also note that the training assumes you have a GPU. To train on CPU, set `accelerator=null` in the config or command-line.
71
 
 
90
 
91
 
92
  ## Generate datasets used in the paper
93
+ The datasets used in the experiments are generated from the starter datasets. In short, for each training/validation/test example, we select a random 5.5 s segment from one of the starter datasets and apply a random number of effects to it. The number of effects applied is controlled by the `num_kept_effects` and `num_removed_effects` parameters, and the specific effects are controlled by the `effects_to_keep` and `effects_to_remove` parameters.
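+ A rough, illustrative sketch of that sampling scheme (not the repo's `EffectDataset` implementation; effect names and ranges are placeholders):
```
# Illustrative only: mirrors the description above.
import random

KEPT_POOL = ["distortion", "reverb", "compression", "delay"]  # effects_to_keep
REMOVED_POOL = ["chorus"]                                      # effects_to_remove

def sample_effect_chain(num_kept=(0, 4), num_removed=(1, 1)):
    kept = random.sample(KEPT_POOL, random.randint(*num_kept))
    removed = random.sample(REMOVED_POOL, random.randint(*num_removed))
    return kept, removed  # applied to a random 5.5 s chunk from a starter dataset
```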
94
+
95
+ Before generating datasets, it is important that you have downloaded the starter datasets (see above) and set DATASET_ROOT.
96
 
97
+ To generate one of the datasets used in the paper, use one of the experiments defined in `cfg/exp/`.
98
+ For example, to generate the `chorus` FXAug dataset, which includes files with 5 possible effects, up to 4 kept effects (distortion, reverb, compression, delay), and 1 removed effect (chorus), run
99
  ```
100
+ python scripts/generate_dataset.py +exp=chorus_aug
101
  ```
102
 
103
  See the Misc. section below for a description of the parameters.
104
  By default, files are rendered to `{render_root} / processed / {string_of_effects} / {train|val|test}`.
105
 
106
+ When training, this rendering step runs automatically at the start of the job. To disable it, set `render_files=False` in the config or on the command line, and set `render_root={path_to_dataset}` if the dataset is in a custom location.
107
+
108
  ## Evaluate with a custom directory
109
  Assumes directory is structured as
110
  - root
 
129
 
130
  # Misc.
131
  ## Experimental parameters
132
+ Descriptions of some relevant dataset/training parameters (a combined example follows this list):
133
  - `num_kept_effects={[min, max]}` range of <b> Kept </b> effects to apply to each file. Inclusive.
134
  - `num_removed_effects={[min, max]}` range of <b> Removed </b> effects to apply to each file. Inclusive.
135
  - `model={model}` architecture to use (see 'Effect Removal Models/Effect Classification Models')
136
+ - `effects_to_keep={[effect]}` Effects to apply but not remove (see 'Effects'). Used for FXAug.
137
  - `effects_to_remove={[effect]}` Effects to remove (see 'Effects')
138
  - `accelerator=null/'gpu'` Use GPU (1 device) (default: null)
139
  - `render_files=True/False` Render files. Disable to skip rendering stage (default: True)
140
  - `render_root={path/to/dir}`. Root directory to render files to (default: ./data)
141
+ - `datamodule.train_batch_size={batch_size}`. Change batch size (default: varies)
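+ Not from the repo: a minimal sketch of combining several of these parameters as overrides through Hydra's Compose API (the experiment name and values are placeholders; `config_path` is resolved relative to the calling file):
```
# Illustrative only: compose the repo config with a few of the parameters above.
from hydra import compose, initialize

with initialize(version_base=None, config_path="cfg"):
    cfg = compose(
        config_name="config",
        overrides=[
            "+exp=chorus_aug",              # any experiment from cfg/exp/
            "num_kept_effects=[0,4]",
            "num_removed_effects=[1,1]",
            "effects_to_remove=[chorus]",
            "render_files=False",
            "datamodule.train_batch_size=8",
        ],
    )
print(cfg.datamodule.train_batch_size)
```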
142
 
143
  ### Effect Removal Models
144
  - `umx`
cfg/config.yaml CHANGED
@@ -63,7 +63,7 @@ datamodule:
63
  shuffle_removed_effects: ${shuffle_removed_effects}
64
  render_files: ${render_files}
65
  render_root: ${render_root}
66
- parallel: True
67
  val_dataset:
68
  _target_: remfx.datasets.EffectDataset
69
  total_chunks: 1000
@@ -80,6 +80,7 @@ datamodule:
80
  shuffle_removed_effects: ${shuffle_removed_effects}
81
  render_files: ${render_files}
82
  render_root: ${render_root}
 
83
  test_dataset:
84
  _target_: remfx.datasets.EffectDataset
85
  total_chunks: 1000
@@ -96,6 +97,7 @@ datamodule:
96
  shuffle_removed_effects: ${shuffle_removed_effects}
97
  render_files: ${render_files}
98
  render_root: ${render_root}
 
99
 
100
  train_batch_size: 16
101
  test_batch_size: 1
@@ -115,6 +117,7 @@ datamodule:
115
  logger:
116
  _target_: pytorch_lightning.loggers.CSVLogger
117
  save_dir: "."
 
118
 
119
  trainer:
120
  _target_: pytorch_lightning.Trainer
 
63
  shuffle_removed_effects: ${shuffle_removed_effects}
64
  render_files: ${render_files}
65
  render_root: ${render_root}
66
+ parallel: False
67
  val_dataset:
68
  _target_: remfx.datasets.EffectDataset
69
  total_chunks: 1000
 
80
  shuffle_removed_effects: ${shuffle_removed_effects}
81
  render_files: ${render_files}
82
  render_root: ${render_root}
83
+ parallel: False
84
  test_dataset:
85
  _target_: remfx.datasets.EffectDataset
86
  total_chunks: 1000
 
97
  shuffle_removed_effects: ${shuffle_removed_effects}
98
  render_files: ${render_files}
99
  render_root: ${render_root}
100
+ parallel: False
101
 
102
  train_batch_size: 16
103
  test_batch_size: 1
 
117
  logger:
118
  _target_: pytorch_lightning.loggers.CSVLogger
119
  save_dir: "."
120
+ version: ${now:%Y-%m-%d-%H-%M-%S}
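For context, roughly what this logger block resolves to in plain Python (a sketch, not repo code; the `${now:...}` resolver is provided by Hydra):
```
# Sketch of the configured logger; CSVLogger's default name is "lightning_logs",
# which matches the log path mentioned in the README.
from datetime import datetime
from pytorch_lightning.loggers import CSVLogger

logger = CSVLogger(
    save_dir=".",
    version=datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
)
print(logger.log_dir)  # ./lightning_logs/<timestamp>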
121
 
122
  trainer:
123
  _target_: pytorch_lightning.Trainer
cfg/exp/chain_inference_aug_classifier.yaml CHANGED
@@ -76,7 +76,6 @@ ckpts:
76
  RandomPedalboardDelay:
77
  model: ${dcunet}
78
  ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
79
-
80
  inference_effects_ordering:
81
  - "RandomPedalboardDistortion"
82
  - "RandomPedalboardCompressor"
 
76
  RandomPedalboardDelay:
77
  model: ${dcunet}
78
  ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
 
79
  inference_effects_ordering:
80
  - "RandomPedalboardDistortion"
81
  - "RandomPedalboardCompressor"
remfx/callbacks.py CHANGED
@@ -42,9 +42,7 @@ class AudioCallback(Callback):
42
  )
43
  self.log_train_audio = False
44
 
45
- def on_validation_batch_start(
46
- self, trainer, pl_module, batch, batch_idx, dataloader_idx
47
- ):
48
  x, target, _, rem_fx_labels = batch
49
  # Only run on first batch
50
  if batch_idx == 0 and self.log_audio:
 
42
  )
43
  self.log_train_audio = False
44
 
45
+ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
 
 
46
  x, target, _, rem_fx_labels = batch
47
  # Only run on first batch
48
  if batch_idx == 0 and self.log_audio:
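For context, this change drops the `dataloader_idx` argument; in pytorch-lightning 2.x (which setup.py now requires) that argument is optional on this hook, so either form is accepted. A hedged sketch:
```
# Sketch only: Lightning 2.x gives dataloader_idx a default on this hook.
import pytorch_lightning as pl

class ExampleCallback(pl.Callback):
    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        pass  # runs before each validation batch
```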
remfx/datasets.py CHANGED
@@ -83,7 +83,7 @@ def locate_files(root: str, mode: str):
83
  print(f"Found {len(files)} files in GuitarSet {mode}.")
84
  file_list.append(sorted(files))
85
  # ------------------------- DSD100 ---------------------------------
86
- dsd_100_dir = os.path.join(root, "DSD100")
87
  if os.path.isdir(dsd_100_dir):
88
  files = glob.glob(
89
  os.path.join(dsd_100_dir, mode, "**", "*.wav"),
@@ -427,7 +427,13 @@ class EffectDataset(Dataset):
427
  chunk = None
428
  random_dataset_choice = random.choice(self.files)
429
  while chunk is None:
430
- random_file_choice = random.choice(random_dataset_choice)
431
  chunk = select_random_chunk(
432
  random_file_choice, self.chunk_size, self.sample_rate
433
  )
@@ -572,7 +578,7 @@ class EffectDataset(Dataset):
572
  normalized_wet = self.normalize(wet)
573
 
574
  # Check STFT, pick different effects if necessary
575
- stft = self.mrstft(normalized_wet, normalized_dry)
576
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
577
 
578
 
 
83
  print(f"Found {len(files)} files in GuitarSet {mode}.")
84
  file_list.append(sorted(files))
85
  # ------------------------- DSD100 ---------------------------------
86
+ dsd_100_dir = os.path.join(root, "DSD100/DSD100")
87
  if os.path.isdir(dsd_100_dir):
88
  files = glob.glob(
89
  os.path.join(dsd_100_dir, mode, "**", "*.wav"),
 
427
  chunk = None
428
  random_dataset_choice = random.choice(self.files)
429
  while chunk is None:
430
+ try:
431
+ random_file_choice = random.choice(random_dataset_choice)
432
+ except IndexError:
433
+ print("IndexError")
434
+ print(random_dataset_choice)
435
+ print("Chosen dataset has an empty file list")
436
+ raise
437
  chunk = select_random_chunk(
438
  random_file_choice, self.chunk_size, self.sample_rate
439
  )
 
578
  normalized_wet = self.normalize(wet)
579
 
580
  # Check STFT, pick different effects if necessary
581
+ stft = self.mrstft(normalized_wet.unsqueeze(0), normalized_dry.unsqueeze(0))
582
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
583
 
584
 
remfx/models.py CHANGED
@@ -4,7 +4,6 @@ import torchmetrics
4
  import pytorch_lightning as pl
5
  from torch import Tensor, nn
6
  from torchaudio.models import HDemucs
7
- from audio_diffusion_pytorch import DiffusionModel
8
  from auraloss.time import SISDRLoss
9
  from auraloss.freq import MultiResolutionSTFTLoss
10
  from umx.openunmix.model import OpenUnmix, Separator
@@ -343,21 +342,6 @@ class DemucsModel(nn.Module):
343
  return self.model(x).squeeze(1)
344
 
345
 
346
- class DiffusionGenerationModel(nn.Module):
347
- def __init__(self, n_channels: int = 1):
348
- super().__init__()
349
- self.model = DiffusionModel(in_channels=n_channels)
350
-
351
- def forward(self, batch):
352
- x, target = batch
353
- sampled_out = self.model.sample(x)
354
- return self.model(x), sampled_out
355
-
356
- def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
357
- noise = torch.randn(x.shape).to(x)
358
- return self.model.sample(noise, num_steps=num_steps)
359
-
360
-
361
  class DPTNetModel(nn.Module):
362
  def __init__(self, sample_rate, num_bins, **kwargs):
363
  super().__init__()
 
4
  import pytorch_lightning as pl
5
  from torch import Tensor, nn
6
  from torchaudio.models import HDemucs
 
7
  from auraloss.time import SISDRLoss
8
  from auraloss.freq import MultiResolutionSTFTLoss
9
  from umx.openunmix.model import OpenUnmix, Separator
 
342
  return self.model(x).squeeze(1)
343
 
344

345
  class DPTNetModel(nn.Module):
346
  def __init__(self, sample_rate, num_bins, **kwargs):
347
  super().__init__()
remfx/utils.py CHANGED
@@ -3,7 +3,6 @@ from typing import List, Tuple
3
  import pytorch_lightning as pl
4
  from omegaconf import DictConfig
5
  from pytorch_lightning.utilities import rank_zero_only
6
- from frechet_audio_distance import FrechetAudioDistance
7
  import numpy as np
8
  import torch
9
  import torchaudio
@@ -52,9 +51,6 @@ def log_hyperparameters(
52
  if not trainer.logger:
53
  return
54
 
55
- if type(trainer.logger) == pl.loggers.CSVLogger:
56
- return
57
-
58
  hparams = {}
59
 
60
  # choose which parts of hydra config will be saved to loggers
@@ -77,38 +73,10 @@ def log_hyperparameters(
77
  if "callbacks" in config:
78
  hparams["callbacks"] = config["callbacks"]
79
 
80
- logger.experiment.config.update(hparams)
81
-
82
-
83
- class FADLoss(torch.nn.Module):
84
- def __init__(self, sample_rate: float):
85
- super().__init__()
86
- self.fad = FrechetAudioDistance(
87
- use_pca=False, use_activation=False, verbose=False
88
- )
89
- self.fad.model = self.fad.model.to("cpu")
90
- self.sr = sample_rate
91
-
92
- def forward(self, audio_background, audio_eval):
93
- embds_background = []
94
- embds_eval = []
95
- for sample in audio_background:
96
- embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
97
- embds_background.append(embd.cpu().detach().numpy())
98
- for sample in audio_eval:
99
- embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
100
- embds_eval.append(embd.cpu().detach().numpy())
101
- embds_background = np.concatenate(embds_background, axis=0)
102
- embds_eval = np.concatenate(embds_eval, axis=0)
103
- mu_background, sigma_background = self.fad.calculate_embd_statistics(
104
- embds_background
105
- )
106
- mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
107
-
108
- fad_score = self.fad.calculate_frechet_distance(
109
- mu_background, sigma_background, mu_eval, sigma_eval
110
- )
111
- return fad_score
112
 
113
 
114
  def create_random_chunks(
 
3
  import pytorch_lightning as pl
4
  from omegaconf import DictConfig
5
  from pytorch_lightning.utilities import rank_zero_only
 
6
  import numpy as np
7
  import torch
8
  import torchaudio
 
51
  if not trainer.logger:
52
  return
53

54
  hparams = {}
55
 
56
  # choose which parts of hydra config will be saved to loggers
 
73
  if "callbacks" in config:
74
  hparams["callbacks"] = config["callbacks"]
75
 
76
+ if type(trainer.logger) == pl.loggers.CSVLogger:
77
+ logger.log_hyperparams(hparams)
78
+ else:
79
+ logger.experiment.config.update(hparams)
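A minimal, standalone sketch of the new CSV branch (illustrative values; `log_hyperparams` writes an `hparams.yaml` under the logger's directory):
```
# Illustrative only: what the CSVLogger path above does in isolation.
from pytorch_lightning.loggers import CSVLogger

csv_logger = CSVLogger(save_dir=".", version="demo")
csv_logger.log_hyperparams({"model": "demucs", "lr": 1e-4})  # placeholder hparams
csv_logger.save()  # flushes ./lightning_logs/demo/hparams.yaml
```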
80
 
81
 
82
  def create_random_chunks(
scripts/download.py CHANGED
@@ -6,54 +6,62 @@ import shutil
6
  def download_zip_dataset(dataset_url: str, output_dir: str):
7
  zip_filename = os.path.basename(dataset_url)
8
  zip_name = zip_filename.replace(".zip", "")
9
- os.system(f"wget -P {output_dir} {dataset_url}")
10
- os.system(
11
- f"""unzip {os.path.join(output_dir, zip_filename)} -d {os.path.join(output_dir, zip_name)}"""
12
- )
13
- os.system(f"rm {os.path.join(output_dir, zip_filename)}")
14
 
15
 
16
  def process_dataset(dataset_dir: str, output_dir: str):
17
- if dataset_dir == "VocalSet1-2":
18
- pass
19
- elif dataset_dir == "audio_mono-mic":
20
  pass
21
- elif dataset_dir == "IDMT-SMT-BASS":
22
  pass
23
- elif dataset_dir == "IDMT-SMT-DRUMS-V2":
24
  pass
25
- elif dataset_dir == "DSD100":
26
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Mixtures"))
27
- for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Dev")):
28
- source = os.path.join(output_dir, dataset_dir, "Sources", "Dev", dir)
29
- shutil.move(source, os.path.join(output_dir, dataset_dir))
30
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Dev"))
31
- for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Test")):
32
- source = os.path.join(output_dir, dataset_dir, "Sources", "Test", dir)
33
- shutil.move(source, os.path.join(output_dir, dataset_dir))
34
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Test"))
35
- shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources"))
36
 
37
- os.mkdir(os.path.join(output_dir, dataset_dir, "train"))
38
- os.mkdir(os.path.join(output_dir, dataset_dir, "val"))
39
- os.mkdir(os.path.join(output_dir, dataset_dir, "test"))
40
- files = os.listdir(os.path.join(output_dir, dataset_dir))
41

42
  num = 0
43
  for dir in files:
44
- if not os.path.isdir(os.path.join(output_dir, dataset_dir, dir)):
45
  continue
46
  if dir == "train" or dir == "val" or dir == "test":
47
  continue
48
- source = os.path.join(output_dir, dataset_dir, dir, "bass.wav")
49
  if num < 80:
50
- dest = os.path.join(output_dir, dataset_dir, "train", f"{num}.wav")
51
  elif num < 90:
52
- dest = os.path.join(output_dir, dataset_dir, "val", f"{num}.wav")
53
  else:
54
- dest = os.path.join(output_dir, dataset_dir, "test", f"{num}.wav")
55
  shutil.move(source, dest)
56
- shutil.rmtree(os.path.join(output_dir, dataset_dir, dir))
57
  num += 1
58
 
59
  else:
@@ -81,11 +89,12 @@ if __name__ == "__main__":
81
  dataset_urls = {
82
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
83
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
84
- "DSD100": "http://liutkus.net/DSD100.zip",
85
- "IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
86
  }
87
 
88
  for dataset_name, dataset_url in dataset_urls.items():
89
  if dataset_name in args.dataset_names:
 
90
  download_zip_dataset(dataset_url, args.output_dir)
91
- process_dataset(dataset_name, args.ou)
 
6
  def download_zip_dataset(dataset_url: str, output_dir: str):
7
  zip_filename = os.path.basename(dataset_url)
8
  zip_name = zip_filename.replace(".zip", "")
9
+ if not os.path.exists(os.path.join(output_dir, zip_name)):
10
+ os.system(f"wget -P {output_dir} {dataset_url}")
11
+ os.system(
12
+ f"""unzip {os.path.join(output_dir, zip_filename)} -d {os.path.join(output_dir, zip_name)}"""
13
+ )
14
+ os.system(f"rm {os.path.join(output_dir, zip_filename)}")
15
+ else:
16
+ print(
17
+ f"Dataset {zip_name} already downloaded at {output_dir}, skipping download."
18
+ )
19
 
20
 
21
  def process_dataset(dataset_dir: str, output_dir: str):
22
+ if dataset_dir == "vocalset":
 
 
23
  pass
24
+ elif dataset_dir == "guitarset":
25
  pass
26
+ elif dataset_dir == "idmt-smt-drums":
27
  pass
28
+ elif dataset_dir == "dsd100":
29
+ dataset_root_dir = "DSD100/DSD100"
30
 
31
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Mixtures"))
32
+ for dir in os.listdir(
33
+ os.path.join(output_dir, dataset_root_dir, "Sources", "Dev")
34
+ ):
35
+ source = os.path.join(output_dir, dataset_root_dir, "Sources", "Dev", dir)
36
+ shutil.move(source, os.path.join(output_dir, dataset_root_dir))
37
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources", "Dev"))
38
+ for dir in os.listdir(
39
+ os.path.join(output_dir, dataset_root_dir, "Sources", "Test")
40
+ ):
41
+ source = os.path.join(output_dir, dataset_root_dir, "Sources", "Test", dir)
42
+ shutil.move(source, os.path.join(output_dir, dataset_root_dir))
43
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources", "Test"))
44
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, "Sources"))
45
 
46
+ os.mkdir(os.path.join(output_dir, dataset_root_dir, "train"))
47
+ os.mkdir(os.path.join(output_dir, dataset_root_dir, "val"))
48
+ os.mkdir(os.path.join(output_dir, dataset_root_dir, "test"))
49
+ files = os.listdir(os.path.join(output_dir, dataset_root_dir))
50
  num = 0
51
  for dir in files:
52
+ if not os.path.isdir(os.path.join(output_dir, dataset_root_dir, dir)):
53
  continue
54
  if dir == "train" or dir == "val" or dir == "test":
55
  continue
56
+ source = os.path.join(output_dir, dataset_root_dir, dir, "bass.wav")
57
  if num < 80:
58
+ dest = os.path.join(output_dir, dataset_root_dir, "train", f"{num}.wav")
59
  elif num < 90:
60
+ dest = os.path.join(output_dir, dataset_root_dir, "val", f"{num}.wav")
61
  else:
62
+ dest = os.path.join(output_dir, dataset_root_dir, "test", f"{num}.wav")
63
  shutil.move(source, dest)
64
+ shutil.rmtree(os.path.join(output_dir, dataset_root_dir, dir))
65
  num += 1
66
 
67
  else:
 
89
  dataset_urls = {
90
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
91
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
92
+ "dsd100": "http://liutkus.net/DSD100.zip",
93
+ "idmt-smt-drums": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
94
  }
95
 
96
  for dataset_name, dataset_url in dataset_urls.items():
97
  if dataset_name in args.dataset_names:
98
+ print("Downloading dataset: ", dataset_name)
99
  download_zip_dataset(dataset_url, args.output_dir)
100
+ process_dataset(dataset_name, args.output_dir)
scripts/generate_dataset.py ADDED
@@ -0,0 +1,15 @@
1
+ import pytorch_lightning as pl
2
+ import hydra
3
+ from omegaconf import DictConfig
4
+
5
+
6
+ @hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
7
+ def main(cfg: DictConfig):
8
+ # Apply seed for reproducibility
9
+ if cfg.seed:
10
+ pl.seed_everything(cfg.seed)
11
+ datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
12
+
13
+
14
+ if __name__ == "__main__":
15
+ main()
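After the script finishes, the rendered chunks should land under the path template noted in the README; a hypothetical sanity check (assuming the default `render_root=./data`):
```
# Hypothetical check, not part of the repo: count rendered examples per split,
# assuming the default render_root of ./data and the README's path template.
from pathlib import Path

for split in ("train", "val", "test"):
    n = len(list(Path("data/processed").glob(f"*/{split}/*")))
    print(f"{split}: {n} rendered examples")
```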
setup.py CHANGED
@@ -35,18 +35,14 @@ setup(
35
  "scipy",
36
  "numpy",
37
  "torchvision",
38
- "pytorch-lightning",
39
  "numba",
40
  "wandb",
41
- "audio-diffusion-pytorch",
42
- "ema_pytorch",
43
  "einops",
44
- "librosa",
45
  "hydra-core",
46
  "auraloss",
47
  "pyloudnorm",
48
  "pedalboard",
49
- "frechet_audio_distance",
50
  "asteroid",
51
  ],
52
  include_package_data=True,
 
35
  "scipy",
36
  "numpy",
37
  "torchvision",
38
+ "pytorch-lightning>=2.0.0",
39
  "numba",
40
  "wandb",
 
 
41
  "einops",
 
42
  "hydra-core",
43
  "auraloss",
44
  "pyloudnorm",
45
  "pedalboard",
 
46
  "asteroid",
47
  ],
48
  include_package_data=True,