import argparse import os CACHE_DIR = os.getenv( "AUDIOLDM_CACHE_DIR", "~/.cache") def get_default_params(model_name): # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) model_name = model_name.lower() if "vit" in model_name: return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} else: return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--train-data", type=str, default=None, help="Path to h5 filewith training data", ) parser.add_argument( "--val-data", type=str, default=None, help="Path to h5 file with validation data", ) parser.add_argument( "--freeze-text", default=False, action="store_true", help="if you need to freeze the text encoder, make this True", ) parser.add_argument( "--freeze-text-after", type=int, default=-1, help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", ) parser.add_argument( "--train-ipc", type=str, default=None, help="Path to npy file of the number of instance per class in training data", ) parser.add_argument( "--val-ipc", type=str, default=None, help="Path to npy file of the number of instance per class in validation data", ) parser.add_argument( "--train-num-samples", type=int, default=None, help="Number of samples in dataset. Required for webdataset if not available in info file.", ) parser.add_argument( "--val-num-samples", type=int, default=None, help="Number of samples in dataset. Useful for webdataset if not available in info file.", ) parser.add_argument( "--dataset-type", choices=["webdataset", "csv", "auto", "toy"], default="auto", help="Which type of dataset to process.", ) parser.add_argument( "--csv-separator", type=str, default="\t", help="For csv-like datasets, which separator to use.", ) parser.add_argument( "--csv-img-key", type=str, default="filepath", help="For csv-like datasets, the name of the key for the image paths.", ) parser.add_argument( "--csv-caption-key", type=str, default="title", help="For csv-like datasets, the name of the key for the captions.", ) parser.add_argument( "--imagenet-val", type=str, default=None, help="Path to imagenet val set for conducting zero shot evaluation.", ) parser.add_argument( "--imagenet-v2", type=str, default=None, help="Path to imagenet v2 for conducting zero shot evaluation.", ) parser.add_argument( "--datasetnames", nargs="+", default=None, help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", ) parser.add_argument( "--full-train-dataset", nargs="+", default=None, help="Which dataset will be trained with all the subsets. (train+test)", ) parser.add_argument( "--exclude-eval-dataset", nargs="+", default=None, help="Which dataset will be excluded with evaluation", ) parser.add_argument( "--datasetinfos", nargs="+", default=None, help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", ) parser.add_argument( "--dataset-proportion", type=float, default=1.0, help="How much proportion of dataset we want to train.", ) parser.add_argument( "--remotedata", default=False, action="store_true", help="if the dataset is remote, set this flag", ) parser.add_argument( "--class-label-path", type=str, default=None, help="The path of the class label pickle or csv.", ) parser.add_argument( "--datasetpath", type=str, default="/mnt/audio_clip/webdataset_tar", help="The path to the dataset", ) parser.add_argument( "--logs", type=str, default="./logs/", help="Where to store tensorboard logs. Use None to avoid storing logs.", ) parser.add_argument( "--log-local", action="store_true", default=False, help="log files on local master, otherwise global master only.", ) parser.add_argument( "--name", type=str, default=None, help="Optional identifier for the experiment when storing logs. Otherwise use current time.", ) parser.add_argument( "--workers", type=int, default=1, help="Number of workers per GPU." ) parser.add_argument( "--batch-size", type=int, default=64, help="Batch size per GPU." ) parser.add_argument( "--epochs", type=int, default=32, help="Number of epochs to train for." ) parser.add_argument("--lr", type=float, default=None, help="Learning rate.") parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") parser.add_argument( "--split-opt", action="store_true", default=False, help="Use this flag to skip the learning rate decay.", ) parser.add_argument( "--lr-pretrained", type=float, default=None, help="Learning rate for text." ) parser.add_argument( "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." ) parser.add_argument( "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." ) parser.add_argument( "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." ) parser.add_argument( "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." ) parser.add_argument( "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." ) parser.add_argument( "--lr-new", type=float, default=None, help="Learning rate for audio." ) parser.add_argument( "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." ) parser.add_argument( "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." ) parser.add_argument( "--eps-new", type=float, default=None, help="Adam epsilon for audio." ) parser.add_argument( "--wd-new", type=float, default=0.2, help="Weight decay for audio." ) parser.add_argument( "--momentum-new", type=float, default=0.9, help="Momentum for audio." ) parser.add_argument( "--warmup", type=int, default=10000, help="Number of steps to warmup for." ) parser.add_argument( "--use-bn-sync", default=False, action="store_true", help="Whether to use batch norm sync.", ) parser.add_argument( "--skip-scheduler", action="store_true", default=False, help="Use this flag to skip the learning rate decay.", ) parser.add_argument( "--save-frequency", type=int, default=1, help="How often to save checkpoints." ) parser.add_argument( "--save-top-performance", type=int, default=0, help="Save the top x performance weights if the value >0", ) parser.add_argument( "--save-most-recent", action="store_true", default=False, help="Always save the most recent model trained to epoch_latest.pt.", ) parser.add_argument( "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." ) parser.add_argument( "--val-frequency", type=int, default=1, help="How often to run evaluation with val data.", ) parser.add_argument( "--resume", default=None, type=str, help="path to latest checkpoint (default: none)", ) parser.add_argument( "--precision", choices=["amp", "fp16", "fp32"], default="amp", help="Floating point precision.", ) parser.add_argument( "--amodel", type=str, default="RN50", help="Name of the audio backbone to use.", ) parser.add_argument( "--tmodel", type=str, default="transformer", help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]", ) parser.add_argument( "--pretrained-audio", default="", type=str, help="Use a pretrained audio model weights for the audio encoder of CLAP", ) parser.add_argument( "--pretrained-text", default="", type=str, help="Use a pretrained text model weights for the text encoder of CLAP", ) parser.add_argument( "--pretrained", default="", type=str, help="Use a pretrained CLIP model weights with the specified tag or file path.", ) parser.add_argument( "--pretrained-image", default=False, action="store_true", help="Load imagenet pretrained weights for image tower backbone if available.", ) parser.add_argument( "--lock-image", default=False, action="store_true", help="Lock full image tower by disabling gradients.", ) parser.add_argument( "--lock-image-unlocked-groups", type=int, default=0, help="Leave last n image tower layer groups unlocked.", ) parser.add_argument( "--lock-image-freeze-bn-stats", default=False, action="store_true", help="Freeze BatchNorm running stats in image tower for any locked layers.", ) parser.add_argument( "--local-loss", default=False, action="store_true", help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", ) parser.add_argument( "--gather-with-grad", default=False, action="store_true", help="enable full distributed gradient for feature gather", ) parser.add_argument( "--force-quick-gelu", default=False, action="store_true", help="Force use of QuickGELU activation for non-OpenAI transformer models.", ) parser.add_argument( "--torchscript", default=False, action="store_true", help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", ) parser.add_argument( "--trace", default=False, action="store_true", help="torch.jit.trace the model for inference / eval only", ) # arguments for distributed training parser.add_argument( "--dist-url", default="env://", type=str, help="url used to set up distributed training", ) parser.add_argument( "--dist-backend", default="nccl", type=str, help="distributed backend" ) parser.add_argument( "--report-to", default="", type=str, help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", ) parser.add_argument( "--wandb-notes", default="", type=str, help="Notes if logging with wandb" ) parser.add_argument( "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." ) parser.add_argument( "--debug", default=False, action="store_true", help="If true, more information is logged.", ) parser.add_argument( "--copy-codebase", default=False, action="store_true", help="If true, we copy the entire base on the log diretory, and execute from there.", ) parser.add_argument( "--horovod", default=False, action="store_true", help="Use horovod for distributed training.", ) parser.add_argument( "--ddp-static-graph", default=False, action="store_true", help="Enable static graph optimization for DDP in PyTorch >= 1.11.", ) parser.add_argument( "--no-set-device-rank", default=False, action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", ) parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") parser.add_argument( "--top-k-checkpoint-select-dataset", type=str, default="all", help="The dataset of selecting top-k checkpoint.", ) # @R10, @R@5, @R1, mAP@10 parser.add_argument( "--top-k-checkpoint-select-metric", type=str, default="_R@10", help="The metric for selecting top-k checkpoint.", ) parser.add_argument( "--openai-model-cache-dir", type=str, default=f"{CACHE_DIR}/clip", help="Directory to download OpenAI models.", ) parser.add_argument( "--optimizer", type=str, default="adamw", help="can be AdamW or SGD", ) parser.add_argument( "--parallel-eval", default=False, action="store_true", help="Eval in parallel (multi-GPU, multi-node).", ) parser.add_argument( "--no-eval", default=False, action="store_true", help="Training without evaluation.", ) parser.add_argument( "--lp-mlp", default=False, action="store_true", help="Linear Probe using MLP layer or not.", ) parser.add_argument( "--lp-freeze", default=False, action="store_true", help="Linear Probe using Freeze CLAP or not", ) parser.add_argument( "--lp-act", default="None", type=str, help="Options are ['relu','elu','prelu','softmax','sigmoid']", ) parser.add_argument( "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." ) parser.add_argument( "--lp-metrics", type=str, default="map,mauc,acc", help="Metrics of Linear Probe.", ) parser.add_argument( "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" ) parser.add_argument( "--kappa", type=float, default=0, help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", ) parser.add_argument( "--data-filling", type=str, default="pad", help="type of data filling when the audio length is shorter than the max length." "Can be one of the following: repeat, repeatpad, pad", ) parser.add_argument( "--data-truncating", type=str, default="rand_trunc", help="type of data truncation when the audio length is longer than the max length." "Can be one of the following: rand_trunc, fusion", ) parser.add_argument( "--clap-mlploss", default=False, action="store_true", help="Using MLP loss for CLAP model or not", ) parser.add_argument( "--wandb-id", type=str, default=None, help="the id of wandb experiment to restore.", ) parser.add_argument( "--sleep", type=float, default=0, help="sleep n seconds before start training" ) # variable length processing parser.add_argument( "--enable-fusion", default=False, action="store_true", help="Enable feature funsion for variable-length data", ) parser.add_argument( "--fusion-type", type=str, default="None", help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", ) parser.add_argument( "--mixup", default=False, action="store_true", help="Enable mixup in finetuning training.", ) parser.add_argument( "--text-augment-selection", type=str, default=None, help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", ) args = parser.parse_args() # If some params are not passed, we use the default values based on model name. default_params = get_default_params(args.amodel) for name, val in default_params.items(): if getattr(args, name) is None: setattr(args, name, val) return args