|
import argparse |
|
|
|
|
|
def get_default_params(model_name):
    """Return the default optimizer hyper-parameters for a model name.

    The lookup is case-insensitive. Names containing "vit" get
    ``beta2=0.98`` and ``eps=1e-6``; every other name gets ``beta2=0.999``
    and ``eps=1e-8``. ``lr`` and ``beta1`` are the same in both cases.

    Args:
        model_name: Model name string; only the presence of the substring
            "vit" is inspected.

    Returns:
        A new dict with the keys ``lr``, ``beta1``, ``beta2`` and ``eps``.
    """
    is_vit = "vit" in model_name.lower()
    beta2, eps = (0.98, 1.0e-6) if is_vit else (0.999, 1.0e-8)
    return {"lr": 5.0e-4, "beta1": 0.9, "beta2": beta2, "eps": eps}
|
|
|
|
|
def parse_args(argv=None):
    """Build the command-line parser, parse the arguments and fill defaults.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            plain ``parse_args()`` call behaves as before.

    Returns:
        argparse.Namespace: The parsed options. For every key returned by
        ``get_default_params(args.amodel)``, an attribute that is still
        ``None`` after parsing is replaced by that default value.
    """
    parser = argparse.ArgumentParser()

    # ---- Data / dataset options ----
    parser.add_argument(
        "--train-data",
        type=str,
        default=None,
        help="Path to h5 file with training data",
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to h5 file with validation data",
    )
    parser.add_argument(
        "--freeze-text",
        default=False,
        action="store_true",
        help="if you need to freeze the text encoder, make this True",
    )
    parser.add_argument(
        "--freeze-text-after",
        type=int,
        default=-1,
        help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
    )
    parser.add_argument(
        "--train-ipc",
        type=str,
        default=None,
        help="Path to npy file of the number of instance per class in training data",
    )
    parser.add_argument(
        "--val-ipc",
        type=str,
        default=None,
        help="Path to npy file of the number of instance per class in validation data",
    )
    parser.add_argument(
        "--train-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Required for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--val-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--dataset-type",
        choices=["webdataset", "csv", "auto", "toy"],
        default="auto",
        help="Which type of dataset to process.",
    )
    parser.add_argument(
        "--csv-separator",
        type=str,
        default="\t",
        help="For csv-like datasets, which separator to use.",
    )
    parser.add_argument(
        "--csv-img-key",
        type=str,
        default="filepath",
        help="For csv-like datasets, the name of the key for the image paths.",
    )
    parser.add_argument(
        "--csv-caption-key",
        type=str,
        default="title",
        help="For csv-like datasets, the name of the key for the captions.",
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        default=None,
        help="Path to imagenet val set for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--imagenet-v2",
        type=str,
        default=None,
        help="Path to imagenet v2 for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--datasetnames",
        nargs="+",
        default=None,
        help="If loading webdataset, specify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
    )
    parser.add_argument(
        "--full-train-dataset",
        nargs="+",
        default=None,
        help="Which dataset will be trained with all the subsets. (train+test)",
    )
    parser.add_argument(
        "--exclude-eval-dataset",
        nargs="+",
        default=None,
        help="Which dataset will be excluded with evaluation",
    )
    parser.add_argument(
        "--datasetinfos",
        nargs="+",
        default=None,
        help="If loading webdataset, specify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
    )
    parser.add_argument(
        "--dataset-proportion",
        type=float,
        default=1.0,
        help="How much proportion of dataset we want to train.",
    )
    parser.add_argument(
        "--remotedata",
        default=False,
        action="store_true",
        help="if the dataset is remote, set this flag",
    )
    parser.add_argument(
        "--class-label-path",
        type=str,
        default=None,
        help="The path of the class label pickle or csv.",
    )
    parser.add_argument(
        "--datasetpath",
        type=str,
        default="/mnt/audio_clip/webdataset_tar",
        help="The path to the dataset",
    )

    # ---- Logging / experiment naming ----
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store tensorboard logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--log-local",
        action="store_true",
        default=False,
        help="log files on local master, otherwise global master only.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )

    # ---- Basic training and optimizer options ----
    parser.add_argument(
        "--workers", type=int, default=1, help="Number of workers per GPU."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size per GPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=32, help="Number of epochs to train for."
    )
    # lr / beta1 / beta2 / eps default to None so that the model-dependent
    # values can be filled in after parsing (see the end of this function).
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--momentum", type=float, default=None, help="SGD momentum.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")

    # NOTE(review): the previous help text here was a verbatim copy of the
    # --skip-scheduler help. The wording below is inferred from the option
    # name and the -pretrained / -new option groups that follow; confirm it
    # against the code that consumes args.split_opt.
    parser.add_argument(
        "--split-opt",
        action="store_true",
        default=False,
        help="Use separate optimizer settings for pretrained and new parameters (see the --*-pretrained and --*-new options).",
    )
    parser.add_argument(
        "--lr-pretrained", type=float, default=None, help="Learning rate for text."
    )
    parser.add_argument(
        "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
    )
    parser.add_argument(
        "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
    )
    parser.add_argument(
        "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
    )
    parser.add_argument(
        "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
    )
    parser.add_argument(
        "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
    )
    parser.add_argument(
        "--lr-new", type=float, default=None, help="Learning rate for audio."
    )
    parser.add_argument(
        "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
    )
    parser.add_argument(
        "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
    )
    parser.add_argument(
        "--eps-new", type=float, default=None, help="Adam epsilon for audio."
    )
    parser.add_argument(
        "--wd-new", type=float, default=0.2, help="Weight decay for audio."
    )
    parser.add_argument(
        "--momentum-new", type=float, default=0.9, help="Momentum for audio."
    )
    parser.add_argument(
        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync.",
    )
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )

    # ---- Checkpointing / evaluation cadence ----
    parser.add_argument(
        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--save-top-performance",
        type=int,
        default=0,
        help="Save the top x performance weights if the value >0",
    )
    parser.add_argument(
        "--save-most-recent",
        action="store_true",
        default=False,
        help="Always save the most recent model trained to epoch_latest.pt.",
    )
    parser.add_argument(
        "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
    )
    parser.add_argument(
        "--val-frequency",
        type=int,
        default=1,
        help="How often to run evaluation with val data.",
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )

    # ---- Model options ----
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    parser.add_argument(
        "--amodel",
        type=str,
        default="RN50",
        help="Name of the audio backbone to use.",
    )
    parser.add_argument(
        "--tmodel",
        type=str,
        default="transformer",
        help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
    )
    parser.add_argument(
        "--pretrained-audio",
        default="",
        type=str,
        help="Use a pretrained audio model weights for the audio encoder of CLAP",
    )
    parser.add_argument(
        "--pretrained-text",
        default="",
        type=str,
        help="Use a pretrained text model weights for the text encoder of CLAP",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--pretrained-image",
        default=False,
        action="store_true",
        help="Load imagenet pretrained weights for image tower backbone if available.",
    )
    parser.add_argument(
        "--lock-image",
        default=False,
        action="store_true",
        help="Lock full image tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-image-unlocked-groups",
        type=int,
        default=0,
        help="Leave last n image tower layer groups unlocked.",
    )
    parser.add_argument(
        "--lock-image-freeze-bn-stats",
        default=False,
        action="store_true",
        help="Freeze BatchNorm running stats in image tower for any locked layers.",
    )
    parser.add_argument(
        "--local-loss",
        default=False,
        action="store_true",
        help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="enable full distributed gradient for feature gather",
    )
    parser.add_argument(
        "--force-quick-gelu",
        default=False,
        action="store_true",
        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
    )
    parser.add_argument(
        "--torchscript",
        default=False,
        action="store_true",
        help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
    )
    parser.add_argument(
        "--trace",
        default=False,
        action="store_true",
        help="torch.jit.trace the model for inference / eval only",
    )

    # ---- Distributed training / reporting ----
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--report-to",
        default="",
        type=str,
        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
    )
    parser.add_argument(
        "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
    )
    parser.add_argument(
        "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged.",
    )
    parser.add_argument(
        "--copy-codebase",
        default=False,
        action="store_true",
        help="If true, we copy the entire base on the log directory, and execute from there.",
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--ddp-static-graph",
        default=False,
        action="store_true",
        help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")

    # ---- Checkpoint selection, caching, optimizer choice ----
    parser.add_argument(
        "--top-k-checkpoint-select-dataset",
        type=str,
        default="all",
        help="The dataset of selecting top-k checkpoint.",
    )
    parser.add_argument(
        "--top-k-checkpoint-select-metric",
        type=str,
        default="_R@10",
        help="The metric for selecting top-k checkpoint.",
    )
    parser.add_argument(
        "--openai-model-cache-dir",
        type=str,
        default="~/.cache/clip",
        help="Directory to download OpenAI models.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adamw",
        help="can be AdamW or SGD",
    )
    parser.add_argument(
        "--parallel-eval",
        default=False,
        action="store_true",
        help="Eval in parallel (multi-GPU, multi-node).",
    )
    parser.add_argument(
        "--no-eval",
        default=False,
        action="store_true",
        help="Training without evaluation.",
    )

    # ---- Linear-probe options ----
    parser.add_argument(
        "--lp-mlp",
        default=False,
        action="store_true",
        help="Linear Probe using MLP layer or not.",
    )
    parser.add_argument(
        "--lp-freeze",
        default=False,
        action="store_true",
        help="Linear Probe using Freeze CLAP or not",
    )
    parser.add_argument(
        "--lp-act",
        default="None",
        type=str,
        help="Options are ['relu','elu','prelu','softmax','sigmoid']",
    )
    parser.add_argument(
        "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
    )
    parser.add_argument(
        "--lp-metrics",
        type=str,
        default="map,mauc,acc",
        help="Metrics of Linear Probe.",
    )
    parser.add_argument(
        "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
    )
    parser.add_argument(
        "--kappa",
        type=float,
        default=0,
        help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
    )

    # ---- Audio length handling / fusion / augmentation ----
    # The two help strings below are built by implicit concatenation; the
    # trailing space on the first piece keeps the sentences separated.
    parser.add_argument(
        "--data-filling",
        type=str,
        default="pad",
        help="type of data filling when the audio length is shorter than the max length. "
        "Can be one of the following: repeat, repeatpad, pad",
    )
    parser.add_argument(
        "--data-truncating",
        type=str,
        default="rand_trunc",
        help="type of data truncation when the audio length is longer than the max length. "
        "Can be one of the following: rand_trunc, fusion",
    )
    parser.add_argument(
        "--clap-mlploss",
        default=False,
        action="store_true",
        help="Using MLP loss for CLAP model or not",
    )
    parser.add_argument(
        "--wandb-id",
        type=str,
        default=None,
        help="the id of wandb experiment to restore.",
    )
    parser.add_argument(
        "--sleep", type=float, default=0, help="sleep n seconds before start training"
    )
    parser.add_argument(
        "--enable-fusion",
        default=False,
        action="store_true",
        help="Enable feature fusion for variable-length data",
    )
    parser.add_argument(
        "--fusion-type",
        type=str,
        default="None",
        help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
    )
    parser.add_argument(
        "--mixup",
        default=False,
        action="store_true",
        help="Enable mixup in finetuning training.",
    )
    parser.add_argument(
        "--text-augment-selection",
        type=str,
        default=None,
        help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
    )

    # argv=None makes argparse read sys.argv[1:], as the original call did.
    args = parser.parse_args(argv)

    # Fill in model-dependent defaults only for options the user left unset
    # (still None); explicitly supplied values are never overwritten.
    default_params = get_default_params(args.amodel)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args
|
|