Spaces:
Runtime error
Runtime error
mattricesound
commited on
Commit
•
bd1743b
1
Parent(s):
e0aa67f
Add dsd100 dataset
Browse files- cfg/config.yaml +4 -4
- cfg/exp/default.yaml +1 -1
- remfx/datasets.py +28 -11
- remfx/models.py +0 -1
- scripts/download.py +39 -6
- shell_vars.sh +1 -1
cfg/config.yaml
CHANGED
@@ -53,9 +53,9 @@ callbacks:
|
|
53 |
_target_: remfx.callbacks.MetricCallback
|
54 |
|
55 |
datamodule:
|
56 |
-
_target_: remfx.datasets.
|
57 |
train_dataset:
|
58 |
-
_target_: remfx.datasets.
|
59 |
sample_rate: ${sample_rate}
|
60 |
root: ${oc.env:DATASET_ROOT}
|
61 |
chunk_size: ${chunk_size}
|
@@ -70,7 +70,7 @@ datamodule:
|
|
70 |
render_files: ${render_files}
|
71 |
render_root: ${render_root}
|
72 |
val_dataset:
|
73 |
-
_target_: remfx.datasets.
|
74 |
sample_rate: ${sample_rate}
|
75 |
root: ${oc.env:DATASET_ROOT}
|
76 |
chunk_size: ${chunk_size}
|
@@ -85,7 +85,7 @@ datamodule:
|
|
85 |
render_files: ${render_files}
|
86 |
render_root: ${render_root}
|
87 |
test_dataset:
|
88 |
-
_target_: remfx.datasets.
|
89 |
sample_rate: ${sample_rate}
|
90 |
root: ${oc.env:DATASET_ROOT}
|
91 |
chunk_size: ${chunk_size}
|
|
|
53 |
_target_: remfx.callbacks.MetricCallback
|
54 |
|
55 |
datamodule:
|
56 |
+
_target_: remfx.datasets.EffectDatamodule
|
57 |
train_dataset:
|
58 |
+
_target_: remfx.datasets.EffectDataset
|
59 |
sample_rate: ${sample_rate}
|
60 |
root: ${oc.env:DATASET_ROOT}
|
61 |
chunk_size: ${chunk_size}
|
|
|
70 |
render_files: ${render_files}
|
71 |
render_root: ${render_root}
|
72 |
val_dataset:
|
73 |
+
_target_: remfx.datasets.EffectDataset
|
74 |
sample_rate: ${sample_rate}
|
75 |
root: ${oc.env:DATASET_ROOT}
|
76 |
chunk_size: ${chunk_size}
|
|
|
85 |
render_files: ${render_files}
|
86 |
render_root: ${render_root}
|
87 |
test_dataset:
|
88 |
+
_target_: remfx.datasets.EffectDataset
|
89 |
sample_rate: ${sample_rate}
|
90 |
root: ${oc.env:DATASET_ROOT}
|
91 |
chunk_size: ${chunk_size}
|
cfg/exp/default.yaml
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# @package _global_
|
2 |
defaults:
|
3 |
-
- override /model:
|
4 |
- override /effects: all
|
5 |
seed: 12345
|
6 |
sample_rate: 48000
|
|
|
1 |
# @package _global_
|
2 |
defaults:
|
3 |
+
- override /model: umx
|
4 |
- override /effects: all
|
5 |
seed: 12345
|
6 |
sample_rate: 48000
|
remfx/datasets.py
CHANGED
@@ -55,6 +55,11 @@ idmt_bass_splits = {
|
|
55 |
"val": ["VIF"],
|
56 |
"test": ["VIS"],
|
57 |
}
|
|
|
|
|
|
|
|
|
|
|
58 |
idmt_drums_splits = {
|
59 |
"train": ["WaveDrum02", "TechnoDrum01"],
|
60 |
"val": ["RealDrum01"],
|
@@ -105,19 +110,28 @@ def locate_files(root: str, mode: str):
|
|
105 |
file_list += sorted(files)
|
106 |
print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
|
107 |
# ------------------------- IDMT-SMT-BASS -------------------------
|
108 |
-
idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
|
109 |
-
if os.path.isdir(idmt_smt_bass_dir):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
files = glob.glob(
|
111 |
-
os.path.join(
|
112 |
recursive=True,
|
113 |
)
|
114 |
-
files = [
|
115 |
-
f
|
116 |
-
for f in files
|
117 |
-
if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
|
118 |
-
]
|
119 |
file_list += sorted(files)
|
120 |
-
print(f"Found {len(files)} files in
|
121 |
# ------------------------- IDMT-SMT-DRUMS -------------------------
|
122 |
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
|
123 |
if os.path.isdir(idmt_smt_drums_dir):
|
@@ -133,7 +147,7 @@ def locate_files(root: str, mode: str):
|
|
133 |
return file_list
|
134 |
|
135 |
|
136 |
-
class
|
137 |
def __init__(
|
138 |
self,
|
139 |
root: str,
|
@@ -199,6 +213,9 @@ class VocalSet(Dataset):
|
|
199 |
if resampled_chunk.shape[-1] < chunk_size:
|
200 |
# Skip if chunk is too small
|
201 |
continue
|
|
|
|
|
|
|
202 |
|
203 |
dry, wet, dry_effects, wet_effects = self.process_effects(
|
204 |
resampled_chunk
|
@@ -334,7 +351,7 @@ class VocalSet(Dataset):
|
|
334 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
335 |
|
336 |
|
337 |
-
class
|
338 |
def __init__(
|
339 |
self,
|
340 |
train_dataset,
|
|
|
55 |
"val": ["VIF"],
|
56 |
"test": ["VIS"],
|
57 |
}
|
58 |
+
dsd_100_splits = {
|
59 |
+
"train": ["train"],
|
60 |
+
"val": ["val"],
|
61 |
+
"test": ["test"],
|
62 |
+
}
|
63 |
idmt_drums_splits = {
|
64 |
"train": ["WaveDrum02", "TechnoDrum01"],
|
65 |
"val": ["RealDrum01"],
|
|
|
110 |
file_list += sorted(files)
|
111 |
print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
|
112 |
# ------------------------- IDMT-SMT-BASS -------------------------
|
113 |
+
# idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
|
114 |
+
# if os.path.isdir(idmt_smt_bass_dir):
|
115 |
+
# files = glob.glob(
|
116 |
+
# os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
|
117 |
+
# recursive=True,
|
118 |
+
# )
|
119 |
+
# files = [
|
120 |
+
# f
|
121 |
+
# for f in files
|
122 |
+
# if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
|
123 |
+
# ]
|
124 |
+
# file_list += sorted(files)
|
125 |
+
# print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
|
126 |
+
# ------------------------- DSD100 ---------------------------------
|
127 |
+
dsd_100_dir = os.path.join(root, "DSD100")
|
128 |
+
if os.path.isdir(dsd_100_dir):
|
129 |
files = glob.glob(
|
130 |
+
os.path.join(dsd_100_dir, mode, "**", "*.wav"),
|
131 |
recursive=True,
|
132 |
)
|
|
|
|
|
|
|
|
|
|
|
133 |
file_list += sorted(files)
|
134 |
+
print(f"Found {len(files)} files in DSD100 {mode}.")
|
135 |
# ------------------------- IDMT-SMT-DRUMS -------------------------
|
136 |
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
|
137 |
if os.path.isdir(idmt_smt_drums_dir):
|
|
|
147 |
return file_list
|
148 |
|
149 |
|
150 |
+
class EffectDataset(Dataset):
|
151 |
def __init__(
|
152 |
self,
|
153 |
root: str,
|
|
|
213 |
if resampled_chunk.shape[-1] < chunk_size:
|
214 |
# Skip if chunk is too small
|
215 |
continue
|
216 |
+
# Sum to mono
|
217 |
+
if resampled_chunk.shape[0] > 1:
|
218 |
+
resampled_chunk = resampled_chunk.sum(0, keepdim=True)
|
219 |
|
220 |
dry, wet, dry_effects, wet_effects = self.process_effects(
|
221 |
resampled_chunk
|
|
|
351 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
352 |
|
353 |
|
354 |
+
class EffectDatamodule(pl.LightningDataModule):
|
355 |
def __init__(
|
356 |
self,
|
357 |
train_dataset,
|
remfx/models.py
CHANGED
@@ -2,7 +2,6 @@ import torch
|
|
2 |
import torchmetrics
|
3 |
import pytorch_lightning as pl
|
4 |
from torch import Tensor, nn
|
5 |
-
from torch.nn import functional as F
|
6 |
from torchaudio.models import HDemucs
|
7 |
from audio_diffusion_pytorch import DiffusionModel
|
8 |
from auraloss.time import SISDRLoss
|
|
|
2 |
import torchmetrics
|
3 |
import pytorch_lightning as pl
|
4 |
from torch import Tensor, nn
|
|
|
5 |
from torchaudio.models import HDemucs
|
6 |
from audio_diffusion_pytorch import DiffusionModel
|
7 |
from auraloss.time import SISDRLoss
|
scripts/download.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import os
|
2 |
-
import sys
|
3 |
-
import glob
|
4 |
-
import torch
|
5 |
import argparse
|
|
|
6 |
|
7 |
|
8 |
def download_zip_dataset(dataset_url: str, output_dir: str):
|
@@ -26,8 +24,42 @@ def process_dataset(dataset_dir: str, output_dir: str):
|
|
26 |
pass
|
27 |
elif dataset_dir == "IDMT-SMT-DRUMS-V2":
|
28 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
else:
|
30 |
-
raise
|
31 |
|
32 |
|
33 |
if __name__ == "__main__":
|
@@ -38,7 +70,7 @@ if __name__ == "__main__":
|
|
38 |
"vocalset",
|
39 |
"guitarset",
|
40 |
"idmt-smt-guitar",
|
41 |
-
"
|
42 |
"idmt-smt-drums",
|
43 |
],
|
44 |
nargs="+",
|
@@ -49,10 +81,11 @@ if __name__ == "__main__":
|
|
49 |
"vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
|
50 |
"guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
|
51 |
"IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
|
52 |
-
"
|
53 |
"IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
|
54 |
}
|
55 |
|
56 |
for dataset_name, dataset_url in dataset_urls.items():
|
57 |
if dataset_name in args.dataset_names:
|
58 |
download_zip_dataset(dataset_url, "~/data/remfx-data")
|
|
|
|
1 |
import os
|
|
|
|
|
|
|
2 |
import argparse
|
3 |
+
import shutil
|
4 |
|
5 |
|
6 |
def download_zip_dataset(dataset_url: str, output_dir: str):
|
|
|
24 |
pass
|
25 |
elif dataset_dir == "IDMT-SMT-DRUMS-V2":
|
26 |
pass
|
27 |
+
elif dataset_dir == "DSD100":
|
28 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Mixtures"))
|
29 |
+
for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Dev")):
|
30 |
+
source = os.path.join(output_dir, dataset_dir, "Sources", "Dev", dir)
|
31 |
+
shutil.move(source, os.path.join(output_dir, dataset_dir))
|
32 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Dev"))
|
33 |
+
for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Test")):
|
34 |
+
source = os.path.join(output_dir, dataset_dir, "Sources", "Test", dir)
|
35 |
+
shutil.move(source, os.path.join(output_dir, dataset_dir))
|
36 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Test"))
|
37 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources"))
|
38 |
+
|
39 |
+
os.mkdir(os.path.join(output_dir, dataset_dir, "train"))
|
40 |
+
os.mkdir(os.path.join(output_dir, dataset_dir, "val"))
|
41 |
+
os.mkdir(os.path.join(output_dir, dataset_dir, "test"))
|
42 |
+
files = os.listdir(os.path.join(output_dir, dataset_dir))
|
43 |
+
|
44 |
+
num = 0
|
45 |
+
for dir in files:
|
46 |
+
if not os.path.isdir(os.path.join(output_dir, dataset_dir, dir)):
|
47 |
+
continue
|
48 |
+
if dir == "train" or dir == "val" or dir == "test":
|
49 |
+
continue
|
50 |
+
source = os.path.join(output_dir, dataset_dir, dir, "bass.wav")
|
51 |
+
if num < 80:
|
52 |
+
dest = os.path.join(output_dir, dataset_dir, "train", f"{num}.wav")
|
53 |
+
elif num < 90:
|
54 |
+
dest = os.path.join(output_dir, dataset_dir, "val", f"{num}.wav")
|
55 |
+
else:
|
56 |
+
dest = os.path.join(output_dir, dataset_dir, "test", f"{num}.wav")
|
57 |
+
shutil.move(source, dest)
|
58 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, dir))
|
59 |
+
num += 1
|
60 |
+
|
61 |
else:
|
62 |
+
raise NotImplementedError(f"Invalid dataset_dir = {dataset_dir}.")
|
63 |
|
64 |
|
65 |
if __name__ == "__main__":
|
|
|
70 |
"vocalset",
|
71 |
"guitarset",
|
72 |
"idmt-smt-guitar",
|
73 |
+
"dsd100",
|
74 |
"idmt-smt-drums",
|
75 |
],
|
76 |
nargs="+",
|
|
|
81 |
"vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
|
82 |
"guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
|
83 |
"IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
|
84 |
+
"DSD100": "http://liutkus.net/DSD100.zip",
|
85 |
"IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
|
86 |
}
|
87 |
|
88 |
for dataset_name, dataset_url in dataset_urls.items():
|
89 |
if dataset_name in args.dataset_names:
|
90 |
download_zip_dataset(dataset_url, "~/data/remfx-data")
|
91 |
+
process_dataset(dataset_name, "~/data/remfx-data")
|
shell_vars.sh
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
export DATASET_ROOT="./data/
|
2 |
export WANDB_PROJECT="RemFX"
|
3 |
export WANDB_ENTITY="mattricesound"
|
|
|
1 |
+
export DATASET_ROOT="./data/"
|
2 |
export WANDB_PROJECT="RemFX"
|
3 |
export WANDB_ENTITY="mattricesound"
|