Spaces:
Running
on
A10G
Running
on
A10G
import argparse | |
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for EVP training and testing.

    Returns:
        An ``argparse.ArgumentParser`` with every option registered.  Option
        names, defaults and destinations are unchanged; call ``parse_args()``
        on the result to obtain the populated namespace.
    """
    parser = argparse.ArgumentParser(description='EVP training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs specified when testing, '
                             'whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device', default='cuda:0', help='device')  # only used when testing on a single machine
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    # Both spellings are accepted and stored in ``local_rank``.  They are
    # registered as one action (instead of two actions sharing a dest) so the
    # option appears only once in ``--help``.
    # NOTE(review): presumably both forms exist because different distributed
    # launcher versions pass different spellings -- confirm before removing one.
    parser.add_argument('--local_rank', '--local-rank', dest='local_rank', type=int, default=0,
                        help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--model_id', default='evp', help='name to identify the model')
    parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_swin_weights', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--split', default='val')
    parser.add_argument('--splitBy', default='unc')
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
    parser.add_argument('--token_length', default=77, type=int)
    return parser
if __name__ == "__main__":
    # Script entry point: build the parser and parse the process arguments.
    parser = get_parser()
    # Despite the name, parse_args() returns an argparse.Namespace, not a dict;
    # use vars(args_dict) if a real dictionary is needed.
    args_dict = parser.parse_args()