import argparse |
def get_default_params(model_name):
    """Return the default Adam optimizer hyper-parameters for a model.

    The match is case-insensitive: any model name containing the substring
    "vit" selects beta2=0.98 and eps=1e-6. Every other name selects
    beta2=0.999 and eps=1e-8. The learning rate (5e-4) and beta1 (0.9) are
    the same in both cases.

    Args:
        model_name: Name of the model architecture (a string).

    Returns:
        A dict with the keys "lr", "beta1", "beta2" and "eps", in that order.
    """
    is_vit = "vit" in model_name.lower()
    return {
        "lr": 5.0e-4,
        "beta1": 0.9,
        "beta2": 0.98 if is_vit else 0.999,
        "eps": 1.0e-6 if is_vit else 1.0e-8,
    }
def parse_args(argv=None):
    """Build the command-line parser, parse arguments and fill optimizer defaults.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]``, so existing callers of
            ``parse_args()`` behave exactly as before. Passing an explicit
            list lets the function be driven programmatically, e.g. from
            tests.

    Returns:
        The populated ``argparse.Namespace``. The optimizer settings ``lr``,
        ``beta1``, ``beta2`` and ``eps`` are never None in the result. Any of
        them not given on the command line is taken from
        ``get_default_params(args.model)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-data",
        type=str,
        default=None,
        help="Path to csv file with training data",
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to csv file with validation data",
    )
    parser.add_argument(
        "--train-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Required for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--val-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--dataset-type",
        choices=["webdataset", "csv", "auto"],
        default="auto",
        help="Which type of dataset to process."
    )
    parser.add_argument(
        "--dataset-resampled",
        default=False,
        action="store_true",
        help="Whether to use sampling with replacement for webdataset shard selection."
    )
    parser.add_argument(
        "--csv-separator",
        type=str,
        default="\t",
        help="For csv-like datasets, which separator to use."
    )
    parser.add_argument(
        "--csv-img-key",
        type=str,
        default="filepath",
        help="For csv-like datasets, the name of the key for the image paths."
    )
    parser.add_argument(
        "--csv-caption-key",
        type=str,
        default="title",
        help="For csv-like datasets, the name of the key for the captions."
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        default=None,
        help="Path to imagenet val set for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--imagenet-v2",
        type=str,
        default=None,
        help="Path to imagenet v2 for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store tensorboard logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--log-local",
        action="store_true",
        default=False,
        help="log files on local master, otherwise global master only.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--workers", type=int, default=1, help="Number of dataloader workers per GPU."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size per GPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=32, help="Number of epochs to train for."
    )
    # The optimizer settings default to None on purpose. None is the sentinel
    # for "not given on the command line"; it is replaced by a model-specific
    # default after parsing (see the end of this function).
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument(
        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync.")
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--save-most-recent",
        action="store_true",
        default=False,
        help="Always save the most recent model trained to epoch_latest.pt.",
    )
    parser.add_argument(
        "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
    )
    parser.add_argument(
        "--val-frequency", type=int, default=1, help="How often to run evaluation with val data."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="RN50",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--pretrained",
        default='',
        type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--pretrained-image",
        default=False,
        action='store_true',
        help="Load imagenet pretrained weights for image tower backbone if available.",
    )
    parser.add_argument(
        "--lock-image",
        default=False,
        action='store_true',
        help="Lock full image tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-image-unlocked-groups",
        type=int,
        default=0,
        help="Leave last n image tower layer groups unlocked.",
    )
    parser.add_argument(
        "--lock-image-freeze-bn-stats",
        default=False,
        action='store_true',
        help="Freeze BatchNorm running stats in image tower for any locked layers.",
    )
    parser.add_argument(
        "--grad-checkpointing",
        default=False,
        action='store_true',
        help="Enable gradient checkpointing.",
    )
    parser.add_argument(
        "--local-loss",
        default=False,
        action="store_true",
        help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)"
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="enable full distributed gradient for feature gather"
    )
    parser.add_argument(
        "--force-quick-gelu",
        default=False,
        action='store_true',
        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
    )
    parser.add_argument(
        "--torchscript",
        default=False,
        action='store_true',
        help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
    )
    parser.add_argument(
        "--trace",
        default=False,
        action='store_true',
        help="torch.jit.trace the model for inference / eval only",
    )
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--report-to",
        default='',
        type=str,
        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
    )
    parser.add_argument(
        "--wandb-notes",
        default='',
        type=str,
        help="Notes if logging with wandb"
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    parser.add_argument(
        "--copy-codebase",
        default=False,
        action="store_true",
        help="If true, we copy the entire codebase to the log directory, and execute from there."
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training."
    )
    parser.add_argument(
        "--ddp-static-graph",
        default=False,
        action='store_true',
        help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)."
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Default random seed."
    )
    # NOTE(review): this flag uses underscores, unlike every other flag. It is
    # kept as-is because renaming it would break existing command lines.
    parser.add_argument(
        "--norm_gradient_clip", type=float, default=None, help="Gradient clip."
    )
    # argv=None makes argparse fall back to sys.argv[1:] (the original behavior).
    args = parser.parse_args(argv)

    # Fill every optimizer setting the user did not supply with the default
    # for the chosen model. A value given explicitly on the CLI is never None,
    # so it is left untouched.
    default_params = get_default_params(args.model)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)
    return args