# RemFx / scripts/download.py
# Last commit by mattricesound: "Update CSV logger" (c1b80c0)
import argparse
import os
import shutil
import urllib.parse
import urllib.request
import zipfile
def download_zip_dataset(dataset_url: str, output_dir: str):
    """Download a zip archive and extract it into ``output_dir/<zip name>``.

    The extraction directory is named after the archive's file name with a
    trailing ``.zip`` removed (e.g. ``.../DSD100.zip`` -> ``output_dir/DSD100``).
    If that directory already exists, nothing is downloaded. The archive is
    deleted once extraction finishes (or fails).

    Args:
        dataset_url: URL of the ``.zip`` archive (any scheme ``urllib`` opens).
        output_dir: Directory that receives the extracted folder; created if
            missing.

    Raises:
        ValueError: If no file name can be derived from ``dataset_url``.
        urllib.error.URLError: If the download fails.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    # Take the name from the URL *path* so query strings such as "?download=1"
    # do not leak into it, and undo percent-encoding.
    zip_filename = os.path.basename(
        urllib.parse.unquote(urllib.parse.urlparse(dataset_url).path)
    )
    # Strip only a trailing ".zip" (str.replace would remove it anywhere).
    if zip_filename.endswith(".zip"):
        zip_name = zip_filename[: -len(".zip")]
    else:
        zip_name = zip_filename
    if not zip_name:
        raise ValueError(f"Cannot derive an archive name from URL {dataset_url!r}.")

    extract_dir = os.path.join(output_dir, zip_name)
    if os.path.exists(extract_dir):
        print(
            f"Dataset {zip_name} already downloaded at {output_dir}, skipping download."
        )
        return

    os.makedirs(output_dir, exist_ok=True)
    # Always "<name>.zip" so the archive path can never collide with extract_dir.
    zip_path = extract_dir + ".zip"
    # Extract into a scratch directory and rename at the end, so an interrupted
    # run never leaves a half-populated extract_dir that the existence check
    # above would mistake for a finished download.
    partial_dir = extract_dir + ".partial"
    try:
        # "wb" overwrites any stale archive left by an earlier interrupted run.
        with urllib.request.urlopen(dataset_url) as response:
            with open(zip_path, "wb") as archive_file:
                shutil.copyfileobj(response, archive_file)
        if os.path.exists(partial_dir):
            shutil.rmtree(partial_dir)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(partial_dir)
        os.rename(partial_dir, extract_dir)
    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)
        shutil.rmtree(partial_dir, ignore_errors=True)
def _process_dsd100(output_dir: str):
    """Turn the extracted DSD100 tree into train/val/test folders of bass stems.

    Expects ``<output_dir>/DSD100/DSD100`` to contain ``Mixtures`` and
    ``Sources/{Dev,Test}/<track>/bass.wav``. Mixtures are deleted; each
    track's ``bass.wav`` is moved to ``train``/``val``/``test`` as
    ``<index>.wav`` (indices 0-79, 80-89 and 90+ respectively, assigned in
    sorted track-name order); all remaining per-track files are deleted.
    """
    root = os.path.join(output_dir, "DSD100", "DSD100")
    sources_dir = os.path.join(root, "Sources")
    splits = ("train", "val", "test")

    # The script entry point calls this even when the download was skipped;
    # on an already-processed tree Mixtures/Sources are gone and the steps
    # below would raise FileNotFoundError, so treat that state as done.
    if not os.path.isdir(sources_dir) and all(
        os.path.isdir(os.path.join(root, split)) for split in splits
    ):
        print(f"Dataset dsd100 already processed at {root}, skipping processing.")
        return

    shutil.rmtree(os.path.join(root, "Mixtures"))
    # Hoist every track folder out of Sources/Dev and Sources/Test into root.
    for subset in ("Dev", "Test"):
        subset_dir = os.path.join(sources_dir, subset)
        for track in sorted(os.listdir(subset_dir)):
            shutil.move(os.path.join(subset_dir, track), root)
    shutil.rmtree(sources_dir)

    for split in splits:
        os.makedirs(os.path.join(root, split), exist_ok=True)

    # Sort so the split is reproducible: os.listdir order is arbitrary and
    # filesystem-dependent.
    tracks = sorted(
        name
        for name in os.listdir(root)
        if name not in splits and os.path.isdir(os.path.join(root, name))
    )
    for index, track in enumerate(tracks):
        if index < 80:
            split = "train"
        elif index < 90:
            split = "val"
        else:
            split = "test"
        # Keep only the bass stem, renamed to its track index.
        shutil.move(
            os.path.join(root, track, "bass.wav"),
            os.path.join(root, split, f"{index}.wav"),
        )
        shutil.rmtree(os.path.join(root, track))


def process_dataset(dataset_dir: str, output_dir: str):
    """Post-process an extracted dataset in place.

    Args:
        dataset_dir: Dataset key: ``"vocalset"``, ``"guitarset"``,
            ``"idmt-smt-drums"`` or ``"dsd100"``.
        output_dir: Directory the dataset was extracted into.

    Raises:
        NotImplementedError: If ``dataset_dir`` is not a known dataset key.
    """
    if dataset_dir in ("vocalset", "guitarset", "idmt-smt-drums"):
        # No post-processing is performed for these datasets.
        return
    if dataset_dir == "dsd100":
        _process_dsd100(output_dir)
        return
    raise NotImplementedError(f"Invalid dataset_dir = {dataset_dir}.")
def main(argv=None):
    """Download the requested datasets and post-process them.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Single source of truth: the CLI choices are the keys of this table.
    dataset_urls = {
        "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
        "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
        "dsd100": "http://liutkus.net/DSD100.zip",
        "idmt-smt-drums": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
    }
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "dataset_names",
        choices=list(dataset_urls),
        nargs="+",
        help="One or more datasets to download.",
    )
    parser.add_argument(
        "--output_dir",
        default="./data/remfx-data",
        help="Directory the datasets are downloaded and extracted into.",
    )
    args = parser.parse_args(argv)
    os.makedirs(args.output_dir, exist_ok=True)
    # Iterate the table (not the CLI list) so each dataset is handled at most
    # once, in a fixed order, even if it is named twice on the command line.
    for dataset_name, dataset_url in dataset_urls.items():
        if dataset_name in args.dataset_names:
            print(f"Downloading dataset: {dataset_name}")
            download_zip_dataset(dataset_url, args.output_dir)
            process_dataset(dataset_name, args.output_dir)


if __name__ == "__main__":
    main()