|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import yaml |
|
|
from yacs.config import CfgNode as CN |
|
|
|
|
|
# Root of the default configuration tree. get_config / get_inference_config
# work on ``_C.clone()``, so these values are defaults that YAML files and
# command-line arguments override without modifying ``_C`` itself.
_C = CN()

# Other YAML files to merge before the current one. _update_config_from_file
# resolves each entry relative to the including file's directory and skips
# empty strings; the key is declared here so YAML files may contain it.
_C.BASE = ['']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Overridden by args.batch_size (when truthy) in update_config.
# NOTE(review): presumably the per-process batch size — confirm in the loader.
_C.DATA.BATCH_SIZE = 32
# Overridden by args.data_path (when truthy) in update_config.
_C.DATA.DATA_PATH = ''
# Dataset name; overridden by args.dataset (when not None) in update_config.
_C.DATA.DATASET = 'imagenet'
_C.DATA.DATASET_ROOT = None
# NOTE(review): presumably the square input side length in pixels — confirm.
_C.DATA.IMG_SIZE = 224
# Resize interpolation modes; presumably for eval and train transforms
# respectively — confirm in the dataset/transform code.
_C.DATA.INTERPOLATION = 'bicubic'
_C.DATA.TRAIN_INTERPOLATION = 'bicubic'
# Set to True by args.zip in update_config.
_C.DATA.ZIP_MODE = False
# Overridden by args.cache_mode (when truthy) in update_config.
_C.DATA.CACHE_MODE = 'part'
_C.DATA.PIN_MEMORY = True
# Overridden by args.num_workers (when not None) in update_config.
_C.DATA.NUM_WORKERS = 4
_C.DATA.TRAIN_PATH = None
_C.DATA.VAL_PATH = None
_C.DATA.NUM_READERS = 4
# Metadata options. NOTE(review): their meaning (how metadata is added,
# fused, or masked) is defined by the data/model code, not here — confirm.
_C.DATA.ADD_META = False
_C.DATA.FUSION = 'early'
_C.DATA.MASK_PROB = 0.0
_C.DATA.MASK_TYPE = 'constant'
_C.DATA.LATE_FUSION_LAYER = -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
_C.MODEL.TYPE = ''
# Model name; update_config also uses it as a path component of OUTPUT
# (OUTPUT/MODEL.NAME/TAG).
_C.MODEL.NAME = ''
# Checkpoint to resume from; overridden by args.resume (when truthy).
_C.MODEL.RESUME = ''
_C.MODEL.NUM_CLASSES = 1000
_C.MODEL.DROP_RATE = 0.0
_C.MODEL.DROP_PATH_RATE = 0.1
_C.MODEL.LABEL_SMOOTHING = 0.1
# Pretrained weights; overridden by args.pretrain (when not None).
_C.MODEL.PRETRAINED = None
# NOTE(review): "DORP" looks like a typo for "DROP". The key names are matched
# by merge_from_file / merge_from_list, so renaming them would break existing
# YAML files and --opts overrides; they are kept as-is.
_C.MODEL.DORP_HEAD = True
_C.MODEL.DORP_META = True
_C.MODEL.FREEZE_BACKBONE = True
_C.MODEL.ONLY_LAST_CLS = False
_C.MODEL.EXTRA_TOKEN_NUM = 1
# Empty by default; expected to be filled in by YAML/CLI overrides.
_C.MODEL.META_DIMS = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
# Overridden by args.epochs (when not None) in update_config.
_C.TRAIN.EPOCHS = 300
# Overridden by args.warmup_epochs (when not None) in update_config.
_C.TRAIN.WARMUP_EPOCHS = 20
# Overridden by args.weight_decay (when not None) in update_config.
_C.TRAIN.WEIGHT_DECAY = 0.05
# Learning rates; overridden by args.lr / args.warmup_lr / args.min_lr
# (each when not None) in update_config.
_C.TRAIN.BASE_LR = 1e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 1e-5
# NOTE(review): presumably a max gradient norm — confirm in the train loop.
_C.TRAIN.CLIP_GRAD = 5.0
_C.TRAIN.AUTO_RESUME = True
# Overridden by args.accumulation_steps (when truthy) in update_config.
# NOTE(review): 0 presumably means "no gradient accumulation" — confirm.
_C.TRAIN.ACCUMULATION_STEPS = 0
# Set to True by args.use_checkpoint in update_config.
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
# Overridden by args.lr_scheduler_name (when not None) in update_config.
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# NOTE(review): presumably only read by step-style schedulers — confirm.
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
_C.TRAIN.OPTIMIZER.EPS = 1e-8
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# NOTE(review): presumably only read by momentum optimizers (SGD) — confirm.
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Augmentation settings
# NOTE(review): names and defaults resemble timm's augmentation / mixup
# arguments; confirm against the code that builds the transforms.
# -----------------------------------------------------------------------------
_C.AUG = CN()
_C.AUG.COLOR_JITTER = 0.4
# Auto-augment policy string.
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random-erasing probability / mode / count (presumably).
_C.AUG.REPROB = 0.25
_C.AUG.REMODE = 'pixel'
_C.AUG.RECOUNT = 1
# Mixup / CutMix parameters.
_C.AUG.MIXUP = 0.8
_C.AUG.CUTMIX = 1.0
_C.AUG.CUTMIX_MINMAX = None
_C.AUG.MIXUP_PROB = 1.0
_C.AUG.MIXUP_SWITCH_PROB = 0.5
_C.AUG.MIXUP_MODE = 'batch'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# NOTE(review): presumably toggles center-cropping in the eval transform.
_C.TEST.CROP = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Overridden by args.amp_opt_level (when truthy) in update_config.
_C.AMP_OPT_LEVEL = ''
# Output root. update_config overrides it from args.output (when truthy) and
# then joins it with MODEL.NAME and TAG (OUTPUT/MODEL.NAME/TAG).
_C.OUTPUT = ''
# Run tag; overridden by args.tag (when truthy) in update_config.
_C.TAG = 'default'
_C.SAVE_FREQ = 1
_C.PRINT_FREQ = 10
_C.SEED = 0
# Set to True by args.eval in update_config.
_C.EVAL_MODE = False
# Set to True by args.throughput in update_config.
_C.THROUGHPUT_MODE = False
# update_config fills this from the LOCAL_RANK environment variable.
_C.LOCAL_RANK = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _update_config_from_file(config, cfg_file):
    """Merge a YAML config file, and first its ``BASE`` files, into ``config``.

    Entries of the file's top-level ``BASE`` list are resolved relative to the
    directory of ``cfg_file`` and merged recursively *before* ``cfg_file``
    itself, so values in ``cfg_file`` take precedence over its bases.

    Args:
        config: yacs ``CfgNode`` updated in place; it is left frozen.
        cfg_file: Path to the YAML file to merge.
    """
    with open(cfg_file, 'r') as f:
        # yacs' merge_from_file (below) parses this same file with
        # yaml.safe_load, so safe_load accepts every file that can be merged
        # at all, without constructing arbitrary Python objects. An empty
        # file parses to None; treat it as an empty mapping as yacs does.
        yaml_cfg = yaml.safe_load(f) or {}

    # Merge base configs first so the current file overrides them.
    # ``or []`` also covers an explicit ``BASE: null``.
    for base in yaml_cfg.get('BASE') or []:
        if base:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), base)
            )

    # The recursive calls above re-freeze the config; unfreeze right before
    # merging this file's values.
    config.defrost()
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
|
|
|
|
|
|
|
|
def update_config(config, args):
    """Apply the YAML file, ``--opts`` and individual CLI flags to ``config``.

    Precedence, lowest to highest:
      1. values already in ``config`` (normally the module defaults),
      2. the YAML file ``args.cfg`` and its ``BASE`` files,
      3. ``args.opts`` (a flat ``[KEY, VALUE, ...]`` list for merge_from_list),
      4. the individual CLI arguments handled below.

    Afterwards ``LOCAL_RANK`` is taken from the environment (if set) and
    ``OUTPUT`` becomes ``OUTPUT/MODEL.NAME/TAG``. ``config`` is modified in
    place and left frozen.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # Flags applied only when truthy.
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True

    # Flags applied whenever given, so explicit zeros are honored.
    if args.num_workers is not None:
        config.DATA.NUM_WORKERS = args.num_workers
    if args.lr is not None:
        config.TRAIN.BASE_LR = args.lr
    if args.min_lr is not None:
        config.TRAIN.MIN_LR = args.min_lr
    if args.warmup_lr is not None:
        config.TRAIN.WARMUP_LR = args.warmup_lr
    if args.warmup_epochs is not None:
        config.TRAIN.WARMUP_EPOCHS = args.warmup_epochs
    if args.weight_decay is not None:
        config.TRAIN.WEIGHT_DECAY = args.weight_decay
    if args.epochs is not None:
        config.TRAIN.EPOCHS = args.epochs
    if args.dataset is not None:
        config.DATA.DATASET = args.dataset
    if args.lr_scheduler_name is not None:
        config.TRAIN.LR_SCHEDULER.NAME = args.lr_scheduler_name
    if args.pretrain is not None:
        config.MODEL.PRETRAINED = args.pretrain

    # Environment variables are strings; the default LOCAL_RANK is an int, so
    # parse it rather than storing "0"/"1"/... into an int field. When the
    # variable is absent (e.g. a plain single-process run), keep the value
    # from the defaults/YAML instead of failing with KeyError.
    local_rank = os.environ.get('LOCAL_RANK')
    if local_rank is not None:
        config.LOCAL_RANK = int(local_rank)

    # Per-run output directory.
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
|
|
|
|
|
|
|
|
def get_config(args):
    """Build the run configuration from defaults, YAML file(s) and CLI args.

    Starts from a clone of the module defaults ``_C``, which is left
    untouched, and applies ``update_config``: the YAML file ``args.cfg`` with
    its ``BASE`` files, then ``args.opts``, then the individual CLI flags.

    Args:
        args: Parsed command-line arguments (must provide every attribute
            read by ``update_config``).

    Returns:
        The resolved, frozen yacs ``CfgNode``.
    """
    config = _C.clone()
    update_config(config, args)
    return config
|
|
|
|
|
|
|
|
|
|
|
def update_inference_config(config, args):
    """Merge an inference YAML config (and its ``BASE`` files) into ``config``.

    Args:
        config: yacs ``CfgNode`` updated in place; it is left frozen.
        args: Either an object exposing the YAML path as ``cfg`` (e.g. an
            argparse namespace, as before) or the path itself.
    """
    # Accept a bare path too; str / os.PathLike objects have no ``cfg``.
    cfg_file = args.cfg if hasattr(args, 'cfg') else args
    _update_config_from_file(config, cfg_file)

    config.defrost()
    # No inference-specific overrides are applied yet; this defrost/freeze
    # pair is where they belong.
    config.freeze()
|
|
|
|
|
|
|
|
def get_inference_config(cfg_path):
    """Return the default config merged with an inference YAML file.

    Args:
        cfg_path: Path to the YAML config. For backward compatibility, an
            object exposing the path as ``cfg`` (e.g. an argparse namespace)
            is also accepted.

    Returns:
        A new frozen yacs ``CfgNode``; the module defaults ``_C`` are cloned,
        not modified.
    """
    from types import SimpleNamespace

    # update_inference_config reads ``args.cfg``; wrap a bare path so callers
    # can pass what the parameter name promises.
    if hasattr(cfg_path, 'cfg'):
        args = cfg_path
    else:
        args = SimpleNamespace(cfg=cfg_path)

    config = _C.clone()
    update_inference_config(config, args)
    return config
|
|
|
|
|
|
|
|
|