""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman

Builds torch optimizers whose parameter groups carry per-group weight decay
and (optionally) a different learning rate for parameters whose names match
configured regex patterns.
"""
import logging
import re
from collections import defaultdict

import torch
from torch import optim as optim

from utils.distributed import is_main_process

logger = logging.getLogger(__name__)

try:
    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
    has_apex = True
except ImportError:
    has_apex = False


def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
    """Pair every trainable parameter of ``model`` with the weight decay it should use.

    Args:
        model: nn.Module whose ``named_parameters()`` are scanned.
        weight_decay: decay assigned to parameters that are not exempted.
        no_decay_list: container of parameter names that always get zero decay.
        filter_bias_and_bn: when True, parameters with a 1-D shape and names
            ending in ".bias" get zero decay.

    Returns:
        List([name, param, weight_decay]) for every parameter with
        ``requires_grad=True``; frozen parameters are skipped.
    """
    named_param_tuples = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
            wd = 0
        elif name in no_decay_list:
            wd = 0
        else:
            wd = weight_decay
        named_param_tuples.append([name, param, wd])
    return named_param_tuples


def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
    """Attach a learning rate to each entry: ``diff_lr`` for names matching any
    pattern in ``diff_lr_names``, otherwise ``default_lr``.

    Args:
        named_param_tuples_or_model: List([name, param, weight_decay]), e.g. the
            output of ``add_weight_decay``. Each entry is unpacked into three
            values, so a bare nn.Module is not accepted despite the name.
        diff_lr_names: List(str) of regex patterns; a parameter matches when
            ``re.search(pattern, name)`` finds a match anywhere in its name.
        diff_lr: float, lr used for matching parameters.
        default_lr: float, lr used for all other parameters.

    Returns:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])
    """
    named_param_tuples_with_lr = []
    logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
    for name, p, wd in named_param_tuples_or_model:
        use_diff_lr = False
        for diff_name in diff_lr_names:
            # Regex search, not equality: a pattern may match any part of the name.
            if re.search(diff_name, name) is not None:
                logger.info(f"param {name} use different_lr: {diff_lr}")
                use_diff_lr = True
                break
        named_param_tuples_with_lr.append(
            [name, p, wd, diff_lr if use_diff_lr else default_lr]
        )

    if is_main_process():
        # Loop variable named `lr` so it does not shadow the `diff_lr` parameter.
        for name, _, wd, lr in named_param_tuples_with_lr:
            logger.info(f"param {name}: wd: {wd}, lr: {lr}")
    return named_param_tuples_with_lr


def create_optimizer_params_group(named_param_tuples_with_lr):
    """Group parameters sharing the same (weight_decay, lr) into one param group.

    Args:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr]).

    Returns:
        List of dicts with keys ``params``, ``weight_decay`` and ``lr`` (one per
        distinct combination, in first-seen order), as accepted by torch.optim
        optimizers.
    """
    # wd -> lr -> [params]; insertion order is preserved, so group order is stable.
    group = defaultdict(lambda: defaultdict(list))
    for _, p, wd, lr in named_param_tuples_with_lr:
        group[wd][lr].append(p)

    optimizer_params_group = []
    for wd, lr_groups in group.items():
        for lr, params in lr_groups.items():
            optimizer_params_group.append(dict(params=params, weight_decay=wd, lr=lr))
            logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(params)}")
    return optimizer_params_group


def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build a torch optimizer for ``model`` from the config ``args``.

    Fields read from ``args``: ``opt``, ``weight_decay``, ``lr``; ``momentum`` for
    the SGD variants; and optionally ``different_lr`` (``enable``, ``module_names``,
    ``lr``), ``opt_eps``, ``opt_betas``, ``opt_args``. If ``model`` defines
    ``no_weight_decay()``, the names it returns get zero weight decay.

    Args:
        args: optimizer config object.
        model: nn.Module to optimize.
        filter_bias_and_bn: forwarded to ``add_weight_decay``.

    Returns:
        A torch.optim.SGD, Adam or AdamW instance.

    Raises:
        AssertionError: if ``args.opt`` contains "fused" but apex or CUDA is unavailable.
        ValueError: if ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay

    # check for modules that requires different lr
    if hasattr(args, "different_lr") and args.different_lr.enable:
        diff_lr_module_names = args.different_lr.module_names
        diff_lr = args.different_lr.lr
    else:
        diff_lr_module_names = []
        diff_lr = None

    no_decay = set()
    if hasattr(model, 'no_weight_decay'):
        no_decay = model.no_weight_decay()
    named_param_tuples = add_weight_decay(
        model, weight_decay, no_decay, filter_bias_and_bn)
    named_param_tuples = add_different_lr(
        named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
    parameters = create_optimizer_params_group(named_param_tuples)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    # Optimizer-wide defaults; the per-group lr/weight_decay in `parameters` take precedence.
    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if getattr(args, 'opt_eps', None) is not None:
        opt_args['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args['betas'] = args.opt_betas
    if getattr(args, 'opt_args', None) is not None:
        opt_args.update(args.opt_args)

    # Keep only the part after the last '_' (e.g. "fused_adam" -> "adam").
    # NOTE(review): fused_* names therefore build the torch optimizers below;
    # the imported apex Fused* classes are not used here.
    opt_lower = opt_lower.split('_')[-1]
    if opt_lower in ('sgd', 'nesterov'):
        opt_args.pop('eps', None)  # SGD takes no eps
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    else:
        # Previously `assert False and "..."`, which raised a message-less
        # AssertionError (and fell through to a bare ValueError only under -O).
        raise ValueError(f"Invalid optimizer: {args.opt}")
    return optimizer