mattricesound commited on
Commit
bd1743b
1 Parent(s): e0aa67f

Add dsd100 dataset

Browse files
cfg/config.yaml CHANGED
@@ -53,9 +53,9 @@ callbacks:
53
  _target_: remfx.callbacks.MetricCallback
54
 
55
  datamodule:
56
- _target_: remfx.datasets.VocalSetDatamodule
57
  train_dataset:
58
- _target_: remfx.datasets.VocalSet
59
  sample_rate: ${sample_rate}
60
  root: ${oc.env:DATASET_ROOT}
61
  chunk_size: ${chunk_size}
@@ -70,7 +70,7 @@ datamodule:
70
  render_files: ${render_files}
71
  render_root: ${render_root}
72
  val_dataset:
73
- _target_: remfx.datasets.VocalSet
74
  sample_rate: ${sample_rate}
75
  root: ${oc.env:DATASET_ROOT}
76
  chunk_size: ${chunk_size}
@@ -85,7 +85,7 @@ datamodule:
85
  render_files: ${render_files}
86
  render_root: ${render_root}
87
  test_dataset:
88
- _target_: remfx.datasets.VocalSet
89
  sample_rate: ${sample_rate}
90
  root: ${oc.env:DATASET_ROOT}
91
  chunk_size: ${chunk_size}
 
53
  _target_: remfx.callbacks.MetricCallback
54
 
55
  datamodule:
56
+ _target_: remfx.datasets.EffectDatamodule
57
  train_dataset:
58
+ _target_: remfx.datasets.EffectDataset
59
  sample_rate: ${sample_rate}
60
  root: ${oc.env:DATASET_ROOT}
61
  chunk_size: ${chunk_size}
 
70
  render_files: ${render_files}
71
  render_root: ${render_root}
72
  val_dataset:
73
+ _target_: remfx.datasets.EffectDataset
74
  sample_rate: ${sample_rate}
75
  root: ${oc.env:DATASET_ROOT}
76
  chunk_size: ${chunk_size}
 
85
  render_files: ${render_files}
86
  render_root: ${render_root}
87
  test_dataset:
88
+ _target_: remfx.datasets.EffectDataset
89
  sample_rate: ${sample_rate}
90
  root: ${oc.env:DATASET_ROOT}
91
  chunk_size: ${chunk_size}
cfg/exp/default.yaml CHANGED
@@ -1,6 +1,6 @@
1
  # @package _global_
2
  defaults:
3
- - override /model: demucs
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
 
1
  # @package _global_
2
  defaults:
3
+ - override /model: umx
4
  - override /effects: all
5
  seed: 12345
6
  sample_rate: 48000
remfx/datasets.py CHANGED
@@ -55,6 +55,11 @@ idmt_bass_splits = {
55
  "val": ["VIF"],
56
  "test": ["VIS"],
57
  }
 
 
 
 
 
58
  idmt_drums_splits = {
59
  "train": ["WaveDrum02", "TechnoDrum01"],
60
  "val": ["RealDrum01"],
@@ -105,19 +110,28 @@ def locate_files(root: str, mode: str):
105
  file_list += sorted(files)
106
  print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
107
  # ------------------------- IDMT-SMT-BASS -------------------------
108
- idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
109
- if os.path.isdir(idmt_smt_bass_dir):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  files = glob.glob(
111
- os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
112
  recursive=True,
113
  )
114
- files = [
115
- f
116
- for f in files
117
- if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
118
- ]
119
  file_list += sorted(files)
120
- print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
121
  # ------------------------- IDMT-SMT-DRUMS -------------------------
122
  idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
123
  if os.path.isdir(idmt_smt_drums_dir):
@@ -133,7 +147,7 @@ def locate_files(root: str, mode: str):
133
  return file_list
134
 
135
 
136
- class VocalSet(Dataset):
137
  def __init__(
138
  self,
139
  root: str,
@@ -199,6 +213,9 @@ class VocalSet(Dataset):
199
  if resampled_chunk.shape[-1] < chunk_size:
200
  # Skip if chunk is too small
201
  continue
 
 
 
202
 
203
  dry, wet, dry_effects, wet_effects = self.process_effects(
204
  resampled_chunk
@@ -334,7 +351,7 @@ class VocalSet(Dataset):
334
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
335
 
336
 
337
- class VocalSetDatamodule(pl.LightningDataModule):
338
  def __init__(
339
  self,
340
  train_dataset,
 
55
  "val": ["VIF"],
56
  "test": ["VIS"],
57
  }
58
+ dsd_100_splits = {
59
+ "train": ["train"],
60
+ "val": ["val"],
61
+ "test": ["test"],
62
+ }
63
  idmt_drums_splits = {
64
  "train": ["WaveDrum02", "TechnoDrum01"],
65
  "val": ["RealDrum01"],
 
110
  file_list += sorted(files)
111
  print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
112
  # ------------------------- IDMT-SMT-BASS -------------------------
113
+ # idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
114
+ # if os.path.isdir(idmt_smt_bass_dir):
115
+ # files = glob.glob(
116
+ # os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
117
+ # recursive=True,
118
+ # )
119
+ # files = [
120
+ # f
121
+ # for f in files
122
+ # if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
123
+ # ]
124
+ # file_list += sorted(files)
125
+ # print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
126
+ # ------------------------- DSD100 ---------------------------------
127
+ dsd_100_dir = os.path.join(root, "DSD100")
128
+ if os.path.isdir(dsd_100_dir):
129
  files = glob.glob(
130
+ os.path.join(dsd_100_dir, mode, "**", "*.wav"),
131
  recursive=True,
132
  )
 
 
 
 
 
133
  file_list += sorted(files)
134
+ print(f"Found {len(files)} files in DSD100 {mode}.")
135
  # ------------------------- IDMT-SMT-DRUMS -------------------------
136
  idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
137
  if os.path.isdir(idmt_smt_drums_dir):
 
147
  return file_list
148
 
149
 
150
+ class EffectDataset(Dataset):
151
  def __init__(
152
  self,
153
  root: str,
 
213
  if resampled_chunk.shape[-1] < chunk_size:
214
  # Skip if chunk is too small
215
  continue
216
+ # Sum to mono
217
+ if resampled_chunk.shape[0] > 1:
218
+ resampled_chunk = resampled_chunk.sum(0, keepdim=True)
219
 
220
  dry, wet, dry_effects, wet_effects = self.process_effects(
221
  resampled_chunk
 
351
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
352
 
353
 
354
+ class EffectDatamodule(pl.LightningDataModule):
355
  def __init__(
356
  self,
357
  train_dataset,
remfx/models.py CHANGED
@@ -2,7 +2,6 @@ import torch
2
  import torchmetrics
3
  import pytorch_lightning as pl
4
  from torch import Tensor, nn
5
- from torch.nn import functional as F
6
  from torchaudio.models import HDemucs
7
  from audio_diffusion_pytorch import DiffusionModel
8
  from auraloss.time import SISDRLoss
 
2
  import torchmetrics
3
  import pytorch_lightning as pl
4
  from torch import Tensor, nn
 
5
  from torchaudio.models import HDemucs
6
  from audio_diffusion_pytorch import DiffusionModel
7
  from auraloss.time import SISDRLoss
scripts/download.py CHANGED
@@ -1,8 +1,6 @@
1
  import os
2
- import sys
3
- import glob
4
- import torch
5
  import argparse
 
6
 
7
 
8
  def download_zip_dataset(dataset_url: str, output_dir: str):
@@ -26,8 +24,42 @@ def process_dataset(dataset_dir: str, output_dir: str):
26
  pass
27
  elif dataset_dir == "IDMT-SMT-DRUMS-V2":
28
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  else:
30
- raise NotImplemented(f"Invalid dataset_dir = {dataset_dir}.")
31
 
32
 
33
  if __name__ == "__main__":
@@ -38,7 +70,7 @@ if __name__ == "__main__":
38
  "vocalset",
39
  "guitarset",
40
  "idmt-smt-guitar",
41
- "idmt-smt-bass",
42
  "idmt-smt-drums",
43
  ],
44
  nargs="+",
@@ -49,10 +81,11 @@ if __name__ == "__main__":
49
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
50
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
51
  "IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
52
- "IDMT-SMT-BASS": "https://zenodo.org/record/7188892/files/IDMT-SMT-BASS.zip",
53
  "IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
54
  }
55
 
56
  for dataset_name, dataset_url in dataset_urls.items():
57
  if dataset_name in args.dataset_names:
58
  download_zip_dataset(dataset_url, "~/data/remfx-data")
 
 
1
  import os
 
 
 
2
  import argparse
3
+ import shutil
4
 
5
 
6
  def download_zip_dataset(dataset_url: str, output_dir: str):
 
24
  pass
25
  elif dataset_dir == "IDMT-SMT-DRUMS-V2":
26
  pass
27
+ elif dataset_dir == "DSD100":
28
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Mixtures"))
29
+ for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Dev")):
30
+ source = os.path.join(output_dir, dataset_dir, "Sources", "Dev", dir)
31
+ shutil.move(source, os.path.join(output_dir, dataset_dir))
32
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Dev"))
33
+ for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Test")):
34
+ source = os.path.join(output_dir, dataset_dir, "Sources", "Test", dir)
35
+ shutil.move(source, os.path.join(output_dir, dataset_dir))
36
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Test"))
37
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources"))
38
+
39
+ os.mkdir(os.path.join(output_dir, dataset_dir, "train"))
40
+ os.mkdir(os.path.join(output_dir, dataset_dir, "val"))
41
+ os.mkdir(os.path.join(output_dir, dataset_dir, "test"))
42
+ files = os.listdir(os.path.join(output_dir, dataset_dir))
43
+
44
+ num = 0
45
+ for dir in files:
46
+ if not os.path.isdir(os.path.join(output_dir, dataset_dir, dir)):
47
+ continue
48
+ if dir == "train" or dir == "val" or dir == "test":
49
+ continue
50
+ source = os.path.join(output_dir, dataset_dir, dir, "bass.wav")
51
+ if num < 80:
52
+ dest = os.path.join(output_dir, dataset_dir, "train", f"{num}.wav")
53
+ elif num < 90:
54
+ dest = os.path.join(output_dir, dataset_dir, "val", f"{num}.wav")
55
+ else:
56
+ dest = os.path.join(output_dir, dataset_dir, "test", f"{num}.wav")
57
+ shutil.move(source, dest)
58
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, dir))
59
+ num += 1
60
+
61
  else:
62
+ raise NotImplementedError(f"Invalid dataset_dir = {dataset_dir}.")
63
 
64
 
65
  if __name__ == "__main__":
 
70
  "vocalset",
71
  "guitarset",
72
  "idmt-smt-guitar",
73
+ "dsd100",
74
  "idmt-smt-drums",
75
  ],
76
  nargs="+",
 
81
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
82
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
83
  "IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
84
+ "DSD100": "http://liutkus.net/DSD100.zip",
85
  "IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
86
  }
87
 
88
  for dataset_name, dataset_url in dataset_urls.items():
89
  if dataset_name in args.dataset_names:
90
  download_zip_dataset(dataset_url, "~/data/remfx-data")
91
+ process_dataset(dataset_name, "~/data/remfx-data")
shell_vars.sh CHANGED
@@ -1,3 +1,3 @@
1
- export DATASET_ROOT="./data/VocalSet"
2
  export WANDB_PROJECT="RemFX"
3
  export WANDB_ENTITY="mattricesound"
 
1
+ export DATASET_ROOT="./data/"
2
  export WANDB_PROJECT="RemFX"
3
  export WANDB_ENTITY="mattricesound"