import math


def adjust_learning_rate(optimizer,
                         base_lr,
                         p,
                         itr,
                         max_itr,
                         restart=1,
                         warm_up_steps=1000,
                         is_cosine_decay=False,
                         min_lr=1e-5,
                         encoder_lr_ratio=1.0,
                         freeze_params=[]):
    """Set the learning rate of every param group for the current iteration.

    Supports linear warm-up, cosine or polynomial (power ``p``) decay,
    optional warm restarts, a scaled learning rate for encoder parameters,
    and a zeroed learning rate / weight decay for frozen parameters.
    Returns the learning rate applied at this iteration (before encoder scaling).
    """
    if restart > 1:
        # With warm restarts, split the schedule into `restart` equal cycles.
        each_max_itr = int(math.ceil(float(max_itr) / restart))
        itr = itr % each_max_itr
        warm_up_steps /= restart
        max_itr = each_max_itr

    if itr < warm_up_steps:
        # Linear warm-up from min_lr to base_lr.
        now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps
    else:
        itr = itr - warm_up_steps
        max_itr = max_itr - warm_up_steps
        if is_cosine_decay:
            # Cosine decay from base_lr down to min_lr.
            now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr / (max_itr + 1)) + 1.) * 0.5
        else:
            # Polynomial ("poly") decay with power p.
            now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p

    for param_group in optimizer.param_groups:
        if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]:
            # Scale the learning rate of encoder parameters.
            param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr
        else:
            param_group['lr'] = now_lr

        for freeze_param in freeze_params:
            if freeze_param in param_group["name"]:
                # Frozen parameters are neither updated nor regularized.
                param_group['lr'] = 0
                param_group['weight_decay'] = 0
                break

    return now_lr


def get_trainable_params(model,
                         base_lr,
                         weight_decay,
                         use_frozen_bn=False,
                         exclusive_wd_dict={},
                         no_wd_keys=[]):
    """Build per-parameter optimizer groups (one group per parameter).

    Each group records the parameter name so that ``adjust_learning_rate``
    can later match encoder / frozen parameters by name. Weight decay is
    disabled for biases, normalization weights, and any key listed in
    ``no_wd_keys``; ``exclusive_wd_dict`` maps key substrings to custom
    weight-decay values.
    """
    params = []
    memo = set()
    total_param = 0
    for key, value in model.named_parameters():
        if value in memo:
            continue
        total_param += value.numel()
        if not value.requires_grad:
            continue
        memo.add(value)

        wd = weight_decay
        for exclusive_key in exclusive_wd_dict.keys():
            if exclusive_key in key:
                wd = exclusive_wd_dict[exclusive_key]
                break

        if len(value.shape) == 1:  # 1-D parameters: biases and normalization weights
            if 'bias' in key:  # biases receive no weight decay
                wd = 0.
            elif not use_frozen_bn:  # trainable normalization weights receive no weight decay
                wd = 0.
            elif 'encoder.' not in key:  # with frozen BN, only the encoder's frozen BN weights keep weight decay
                wd = 0.
        else:
            for no_wd_key in no_wd_keys:
                if no_wd_key in key:
                    wd = 0.
                    break

        params += [{
            "params": [value],
            "lr": base_lr,
            "weight_decay": wd,
            "name": key
        }]

    print('Total Param: {:.2f}M'.format(total_param / 1e6))
    return params


def freeze_params(module):
    """Disable gradients for every parameter of the given module."""
    for p in module.parameters():
        p.requires_grad = False


def calculate_params(state_dict):
    """Print the total number of unique parameters in a state dict, in millions."""
    memo = set()
    total_param = 0
    for key, value in state_dict.items():
        if value in memo:
            continue
        memo.add(value)
        total_param += value.numel()
    print('Total Param: {:.2f}M'.format(total_param / 1e6))
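

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: it assumes PyTorch
    # is available and uses a hypothetical toy model and hyper-parameters purely
    # for illustration. It shows the intended flow: build per-parameter groups
    # with get_trainable_params, hand them to an optimizer, then call
    # adjust_learning_rate once per training iteration.
    import torch

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3),
                                torch.nn.BatchNorm2d(16))
    params = get_trainable_params(model, base_lr=2e-4, weight_decay=0.01)
    optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=0.01)

    total_steps = 100000
    for step in range(total_steps):
        now_lr = adjust_learning_rate(optimizer,
                                      base_lr=2e-4,
                                      p=0.9,
                                      itr=step,
                                      max_itr=total_steps,
                                      warm_up_steps=1000,
                                      is_cosine_decay=True)
        # ... forward pass, loss.backward(), optimizer.step() would go here ...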