File size: 9,228 Bytes
8d82201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import argparse
def get_parser():
    """Build the command-line argument parser for LAVT training and testing.

    Returns:
        argparse.ArgumentParser: parser with every LAVT option registered.
        Call ``parse_args()`` on it to obtain an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='LAVT training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    # Adjacent string literals are concatenated verbatim, so each fragment of a
    # multi-part help text below ends with a space to keep the output readable.
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs to be specified when testing, '
                             'whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device', default='cuda:0', help='device')  # only used when testing on a single machine
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--mha', default='',
                        help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4, '
                             'where a, b, c, and d refer to the numbers of heads in stage-1, '
                             'stage-2, stage-3, and stage-4 PWAMs')
    parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
    parser.add_argument('--model_id', default='lavt', help='name to identify the model')
    parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_swin_weights', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--split', default='test', help='only used when testing')
    parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
    parser.add_argument('--swin_type', default='base',
                        help='tiny, small, base, or large variants of the Swin Transformer')
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
                        dest='weight_decay')
    parser.add_argument('--window12', action='store_true',
                        help='only needs to be specified when testing; '
                             'when training, window size is inferred from pre-trained weights file name '
                             '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
    # NOTE(review): this default is a placeholder string, not a real path;
    # presumably callers are expected to always pass --config explicitly.
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    return parser
# -----------------------------------------------------------------------------
# Functions for parsing args
# -----------------------------------------------------------------------------
import copy
import os
from ast import literal_eval
import yaml
class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """Build a node, recursively wrapping nested plain dicts as CfgNodes.

        Args:
            init_dict: mapping of initial keys/values. It is NOT modified; a
                fresh mapping is built instead.
            key_list: path of keys from the root to this node. It is only
                forwarded to child nodes and not otherwise used here.
            new_allowed: accepted for API compatibility; currently unused.
        """
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        # Build a new mapping rather than writing CfgNodes back into the
        # caller's dict, which would silently change the caller's data.
        converted = {}
        for k, v in init_dict.items():
            # Exact type check: only plain dicts are converted. Values that are
            # already CfgNodes (or other mapping types) are kept as given.
            if type(v) is dict:
                v = CfgNode(v, key_list=key_list + [k])
            converted[k] = v
        super(CfgNode, self).__init__(converted)

    def __getattr__(self, name):
        """Return ``self[name]``; raise AttributeError (not KeyError) if absent.

        Raising AttributeError keeps ``hasattr``/``getattr(..., default)`` and
        protocol probes such as ``copy.deepcopy``'s ``__deepcopy__`` lookup working.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        """Store attribute assignments as dict items (``cfg.x = 1`` sets ``cfg['x']``)."""
        self[name] = value

    def __delattr__(self, name):
        """Delete the item ``name`` (``del cfg.x``), mirroring ``__setattr__``."""
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __str__(self):
        """Render the tree as sorted ``key: value`` lines, indenting nested nodes by 2 spaces."""
        def _indent(s_, num_spaces):
            # Indent every line except the first by ``num_spaces`` spaces.
            lines = s_.split("\n")
            if len(lines) == 1:
                return s_
            first = lines.pop(0)
            rest = [(num_spaces * " ") + line for line in lines]
            return first + "\n" + "\n".join(rest)

        entries = []
        for k, v in sorted(self.items()):
            # Nested nodes start on their own line below the key.
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            entries.append(_indent(attr_str, 2))
        return "\n".join(entries)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__,
                               super(CfgNode, self).__repr__())
def load_cfg_from_cfg_file(file):
    """Load a YAML config file and flatten its top-level sections into one CfgNode.

    The file is read as a mapping of section names to mappings of options,
    for example::

        TRAIN:
          lr: 0.001
        TEST:
          split: val

    Every option from every section is copied into a single flat namespace
    (``cfg.lr``, ``cfg.split``). Section names are discarded, so an option that
    appears in several sections keeps the value from the section iterated last.

    Args:
        file: path to a ``.yaml`` file.

    Returns:
        CfgNode holding the flattened options. It is empty for an empty file.

    Raises:
        AssertionError: if ``file`` is not an existing file ending in ``.yaml``.
    """
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)
    with open(file, 'r', encoding='utf-8') as f:
        # safe_load only builds plain Python types, never arbitrary objects.
        # An empty file loads as None; treat it as having no sections.
        cfg_from_file = yaml.safe_load(f) or {}
    cfg = {}
    for section in cfg_from_file.values():
        # A section header with no body (e.g. ``TRAIN:``) loads as None.
        for k, v in (section or {}).items():
            cfg[k] = v
    return CfgNode(cfg)
def merge_cfg_from_list(cfg, cfg_list):
    """Return a deep copy of ``cfg`` with overrides from a flat key/value list applied.

    ``cfg_list`` alternates keys and raw values, for example
    ``['TRAIN.lr', '0.01', 'split', 'val']``. Only the last dotted component of
    each key is looked up, because the config is a single flat namespace.
    Each raw value is decoded with ``_decode_cfg_value`` and must be
    type-compatible with the existing value (see
    ``_check_and_coerce_cfg_value_type``). ``cfg`` itself is left unchanged.

    Raises:
        AssertionError: if ``cfg_list`` has odd length or names an unknown key.
        ValueError: if a value's type is incompatible with the existing one.
    """
    merged = copy.deepcopy(cfg)
    assert len(cfg_list) % 2 == 0
    keys = cfg_list[0::2]
    raw_values = cfg_list[1::2]
    for full_key, raw in zip(keys, raw_values):
        leaf = full_key.split('.')[-1]
        assert leaf in cfg, 'Non-existent key: {}'.format(full_key)
        decoded = _decode_cfg_value(raw)
        checked = _check_and_coerce_cfg_value_type(decoded, cfg[leaf], leaf, full_key)
        setattr(merged, leaf, checked)
    return merged
def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
line argument) into a Python object.
"""
# All remaining processing is only applied to strings
if not isinstance(v, str):
return v
# Try to interpret `v` as a:
# string, number, tuple, list, dict, boolean, or None
try:
v = literal_eval(v)
# The following two excepts allow v to pass through when it represents a
# string.
#
# Longer explanation:
# The type of v is always a string (before calling literal_eval), but
# sometimes it *represents* a string and other times a data structure, like
# a list. In the case that v represents a string, what we got back from the
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
# will raise a SyntaxError.
except ValueError:
pass
except SyntaxError:
pass
return v
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
"""Checks that `replacement`, which is intended to replace `original` is of
the right type. The type is correct if it matches exactly or is one of a few
cases in which the type can be easily coerced.
"""
original_type = type(original)
replacement_type = type(replacement)
# The types must match (with some exceptions)
if replacement_type == original_type:
return replacement
# Cast replacement from from_type to to_type if the replacement and original
# types match from_type and to_type
def conditional_cast(from_type, to_type):
if replacement_type == from_type and original_type == to_type:
return True, to_type(replacement)
else:
return False, None
# Conditionally casts
# list <-> tuple
casts = [(tuple, list), (list, tuple)]
# For py2: allow converting from str (bytes) to a unicode string
try:
casts.append((str, unicode)) # noqa: F821
except Exception:
pass
for (from_type, to_type) in casts:
converted, converted_value = conditional_cast(from_type, to_type)
if converted:
return converted_value
raise ValueError(
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
"key: {}".format(original_type, replacement_type, original,
replacement, full_key))
if __name__ == "__main__":
    # When run directly: build the CLI parser and parse sys.argv. The result
    # (an argparse.Namespace, despite the variable name) is not used further,
    # so this only validates the given flags.
    parser = get_parser()
    args_dict = parser.parse_args(args=None)
|