# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Command line definition for Demucs and experiment naming from parsed args."""

import argparse
import os
from pathlib import Path


def get_parser() -> argparse.ArgumentParser:
    """
    Build the `demucs` argument parser.

    The defaults of `--raw` and `--musdb` are read from the `DEMUCS_RAW` and
    `DEMUCS_MUSDB` environment variables when they are set, and are `None`
    otherwise.

    The order in which arguments are declared here is significant: `get_name`
    iterates over the parsed namespace, whose attribute order follows the
    declaration order.
    """
    parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
    default_raw = Path(os.environ['DEMUCS_RAW']) if 'DEMUCS_RAW' in os.environ else None
    default_musdb = Path(os.environ['DEMUCS_MUSDB']) if 'DEMUCS_MUSDB' in os.environ else None

    parser.add_argument(
        "--raw",
        type=Path,
        default=default_raw,
        help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
    # Shares its destination with --raw: resets `raw` to None, which is useful
    # to override a default coming from the DEMUCS_RAW environment variable.
    parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
    parser.add_argument("-m",
                        "--musdb",
                        type=Path,
                        default=default_musdb,
                        help="Path to musdb root")
    parser.add_argument("--is_wav", action="store_true",
                        help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
    parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
                        help="Folder where metadata information is stored.")
    parser.add_argument("--wav", type=Path,
                        help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
                             "subfolder.")
    parser.add_argument("--samplerate", type=int, default=44100)
    parser.add_argument("--audio_channels", type=int, default=2)
    parser.add_argument("--samples",
                        default=44100 * 10,
                        type=int,
                        help="number of samples to feed in")
    parser.add_argument("--data_stride",
                        default=44100,
                        type=int,
                        help="Stride for chunks, shorter = longer epochs")
    parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
    parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
    parser.add_argument("-d",
                        "--device",
                        help="Device to train on, default is cuda if available else cpu")
    parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
    parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
    parser.add_argument("--test", help="Just run the test pipeline + one validation. "
                                       "This should be a filename relative to the models/ folder.")
    parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
                                                  "on a pretrained model. ")

    parser.add_argument("--rank", default=0, type=int)
    parser.add_argument("--world_size", default=1, type=int)
    parser.add_argument("--master")

    parser.add_argument("--checkpoints",
                        type=Path,
                        default=Path("checkpoints"),
                        help="Folder where to store checkpoints etc")
    parser.add_argument("--evals",
                        type=Path,
                        default=Path("evals"),
                        help="Folder where to store evals and waveforms")
    parser.add_argument("--save",
                        action="store_true",
                        help="Save estimated for the test set waveforms")
    parser.add_argument("--logs",
                        type=Path,
                        default=Path("logs"),
                        help="Folder where to store logs")
    parser.add_argument("--models",
                        type=Path,
                        default=Path("models"),
                        help="Folder where to store trained models")
    parser.add_argument("-R",
                        "--restart",
                        action='store_true',
                        help='Restart training, ignoring previous run')

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs")
    parser.add_argument("-r",
                        "--repeat",
                        type=int,
                        default=2,
                        help="Repeat the train set, longer epochs")
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1")
    parser.add_argument("--init", help="Initialize from a pre-trained model.")

    # Augmentation options
    parser.add_argument("--no_augment",
                        action="store_false",
                        dest="augment",
                        default=True,
                        help="No basic data augmentation.")
    parser.add_argument("--repitch", type=float, default=0.2,
                        help="Probability to do tempo/pitch change")
    # `%%` is required: argparse applies %-formatting to help strings.
    parser.add_argument("--max_tempo", type=float, default=12,
                        help="Maximum relative tempo change in %% when using repitch.")

    parser.add_argument("--remix_group_size",
                        type=int,
                        default=4,
                        help="Shuffle sources using group of this size. Useful to somewhat "
                             "replicate multi-gpu training "
                             "on less GPUs.")
    parser.add_argument("--shifts",
                        type=int,
                        default=10,
                        help="Number of random shifts used for the shift trick.")
    parser.add_argument("--overlap",
                        type=float,
                        default=0.25,
                        help="Overlap when --split_valid is passed.")

    # See model.py for doc
    parser.add_argument("--growth",
                        type=float,
                        default=2.,
                        help="Number of channels between two layers will increase by this factor")
    parser.add_argument("--depth",
                        type=int,
                        default=6,
                        help="Number of layers for the encoder and decoder")
    parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
    parser.add_argument("--channels",
                        type=int,
                        default=64,
                        help="Number of channels for the first encoder layer")
    parser.add_argument("--kernel_size",
                        type=int,
                        default=8,
                        help="Kernel size for the (transposed) convolutions")
    parser.add_argument("--conv_stride",
                        type=int,
                        default=4,
                        help="Stride for the (transposed) convolutions")
    parser.add_argument("--context",
                        type=int,
                        default=3,
                        help="Context size for the decoder convolutions "
                             "before the transposed convolutions")
    parser.add_argument("--rescale",
                        type=float,
                        default=0.1,
                        help="Initial weight rescale reference")
    parser.add_argument("--no_resample", action="store_false",
                        default=True, dest="resample",
                        help="No Resampling of the input/output x2")
    parser.add_argument("--no_glu",
                        action="store_false",
                        default=True,
                        dest="glu",
                        help="Replace all GLUs by ReLUs")
    parser.add_argument("--no_rewrite",
                        action="store_false",
                        default=True,
                        dest="rewrite",
                        help="No 1x1 rewrite convolutions")
    parser.add_argument("--normalize", action="store_true")
    parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True)

    # Tasnet options
    parser.add_argument("--tasnet", action="store_true")
    parser.add_argument("--split_valid",
                        action="store_true",
                        help="Predict chunks by chunks for valid and test. Required for tasnet")
    parser.add_argument("--X", type=int, default=8)

    # Other options
    parser.add_argument("--show",
                        action="store_true",
                        help="Show model architecture, size and exit")
    parser.add_argument("--save_model", action="store_true",
                        help="Skip training, just save final model "
                             "for the current checkpoint value.")
    parser.add_argument("--save_state",
                        help="Skip training, just save state "
                             "for the current checkpoint value. You should "
                             "provide a model name as argument.")

    # Quantization options
    parser.add_argument("--q-min-size", type=float, default=1,
                        help="Only quantize layers over this size (in MB)")
    parser.add_argument(
        "--qat", type=int, help="If provided, use QAT training with that many bits.")
    parser.add_argument("--diffq", type=float, default=0)
    parser.add_argument(
        "--ms-target", type=float, default=162,
        help="Model size target in MB, when using DiffQ. Best model will be kept "
             "only if it is smaller than this target.")

    return parser


def get_name(parser: argparse.ArgumentParser, args: argparse.Namespace) -> str:
    """
    Return the name of an experiment given the args. Some parameters are ignored,
    for instance --workers, as they do not impact the final result.

    Every non-ignored argument whose value differs from the parser default
    contributes a `name=value` part; `Path` values contribute only their last
    component. Parts are joined with a single space, in the order the
    attributes appear in `args`. If nothing differs from the defaults, the
    name is `"default"`.
    """
    # NOTE(review): "deterministic" and "eval" are not declared by `get_parser`
    # in this file; they are kept as-is, presumably for callers that add them
    # to the namespace.
    ignore_args = {
        "checkpoints",
        "deterministic",
        "eval",
        "evals",
        "eval_cpu",
        "eval_workers",
        "logs",
        "master",
        "rank",
        "restart",
        "save",
        "save_model",
        "save_state",
        "show",
        "workers",
        "world_size",
    }
    parts = []
    for name, value in vars(args).items():
        if name in ignore_args:
            continue
        if value == parser.get_default(name):
            continue
        # Only the final path component is used, so the name does not depend
        # on where the data lives on a given machine.
        shown = value.name if isinstance(value, Path) else value
        parts.append(f"{name}={shown}")
    return " ".join(parts) if parts else "default"