# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement helper functions for building optimizers and schedulers."""
import numpy as np
import torch

import utils


def clip_gradients(model, clip):
    """Clip each parameter's gradient whose L2 norm exceeds `clip`; return the per-parameter norms."""
    norms = []
    for name, p in model.named_parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            norms.append(param_norm.item())
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                p.grad.data.mul_(clip_coef)
    return norms


def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """Zero out the last layer's gradients while epoch < freeze_last_layer."""
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None


def cosine_scheduler(
    base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
):
    """
    Linear warmup from start_warmup_value to base_value over the first warmup_epochs epochs,
    then cosine decay from base_value to final_value over the remaining epochs - warmup_epochs.
    Returns one value per training iteration.
    """
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters))
    )

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule


def get_params_groups(model):
    """Divide the parameters into several groups, see below."""
    regularized = []
    not_regularized = []
    patch_embed = []
    patch_embed_not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # we do not regularize biases nor Norm parameters
        if name.endswith(".bias") or len(param.shape) == 1:
            if "patch_embed" in name:
                patch_embed_not_regularized.append(param)
            else:
                not_regularized.append(param)
        elif "patch_embed" in name:
            patch_embed.append(param)
        else:
            regularized.append(param)
    return [
        {"name": "normal_params", "params": regularized},
        {"name": "patch_embed", "params": patch_embed},
        {
            "name": "no_wd",
            "params": not_regularized,
            "apply_wd": False,
            "weight_decay": 0.0,
        },
        {
            "name": "patch_embed_no_wd",
            "params": patch_embed_not_regularized,
            "apply_wd": False,
            "weight_decay": 0.0,
        },
    ]


class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
    """

    def __init__(
        self,
        params,
        lr=0,
        weight_decay=0,
        momentum=0.9,
        eta=0.001,
        weight_decay_filter=None,
        lars_adaptation_filter=None,
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            eta=eta,
            weight_decay_filter=weight_decay_filter,
            lars_adaptation_filter=lars_adaptation_filter,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad
                if dp is None:
                    continue
                # weight decay and LARS adaptation are skipped for 1-d params (biases, norms)
                if p.ndim != 1:
                    dp = dp.add(p, alpha=g["weight_decay"])
                if p.ndim != 1:
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # trust ratio eta * ||p|| / ||dp||, falling back to 1 for zero norms
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0, (g["eta"] * param_norm / update_norm), one
                        ),
                        one,
                    )
                    dp = dp.mul(q)
                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)
                p.add_(mu, alpha=-g["lr"])


def get_optimizer(student, len_dataloader, args):
    """Build the optimizer, fp16 scaler, and lr/wd/momentum schedules for training."""
    # ============ preparing optimizer ... ============
    params_groups = get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            params_groups, lr=0, momentum=0.9
        )  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = LARS(params_groups)  # to use with convnet and large batches
    else:
        raise ValueError(f"Unknown optimizer: {args.optimizer}")

    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = cosine_scheduler(
        args.lr
        * (args.batch_size_per_gpu * utils.get_world_size())
        / 256.0,  # linear scaling rule
        args.min_lr,
        args.epochs,
        len_dataloader,
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len_dataloader,
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = cosine_scheduler(
        args.momentum_teacher, 1, args.epochs, len_dataloader
    )
    print("Loss, optimizer and schedulers ready.")
    return optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule
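

# Usage sketch: how the returned schedules are typically consumed, assuming a
# DINO-style training loop in the caller; `student`, `data_loader`, and `args` are
# placeholders defined by that script, not by this module.
#
#     optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule = get_optimizer(
#         student, len(data_loader), args
#     )
#     for epoch in range(args.epochs):
#         for batch_idx, batch in enumerate(data_loader):
#             it = epoch * len(data_loader) + batch_idx  # global iteration index
#             for param_group in optimizer.param_groups:
#                 param_group["lr"] = lr_schedule[it]
#                 if param_group.get("apply_wd", True):  # skip the no-weight-decay groups
#                     param_group["weight_decay"] = wd_schedule[it]
#             m = momentum_schedule[it]  # EMA momentum for the teacher update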