""" |
|
implment some functions for optimizers |
|
""" |
|
import numpy as np |
|
import torch |
|
|
|
import utils |
|
|
|
|
|
def clip_gradients(model, clip):
    """
    Rescale each parameter's gradient so that its L2 norm does not exceed
    `clip`; return the list of (pre-clipping) per-parameter gradient norms.
    """
    norms = []
    for name, p in model.named_parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            norms.append(param_norm.item())
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                # Scale the gradient in place so its norm becomes ~`clip`.
                p.grad.data.mul_(clip_coef)
    return norms
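
# Usage sketch for clip_gradients (hypothetical training loop; `model`,
# `loss` and `optimizer` are assumed to exist in the caller, not here):
#
#     loss.backward()
#     grad_norms = clip_gradients(model, clip=3.0)
#     optimizer.step()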
|
|
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """
    Zero out the last layer's gradients while epoch < freeze_last_layer,
    so that the last layer stays frozen during the first epochs.
    """
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None
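
# Usage sketch for cancel_gradients_last_layer (hypothetical; in a training
# loop it is typically called between loss.backward() and optimizer.step()):
#
#     loss.backward()
#     cancel_gradients_last_layer(epoch, model, freeze_last_layer=1)
#     optimizer.step()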
|
|
def cosine_scheduler(
    base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
):
    """
    Build a per-iteration schedule: linear warmup from start_warmup_value to
    base_value over the first warmup_epochs epochs, then cosine decay from
    base_value to final_value over the remaining epochs - warmup_epochs.
    """
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    # Cosine decay over the post-warmup iterations.
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters))
    )

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
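
# Example for cosine_scheduler (a minimal sketch; the values are illustrative):
#
#     lr_schedule = cosine_scheduler(
#         base_value=1e-3, final_value=1e-5, epochs=10, niter_per_ep=100,
#         warmup_epochs=1,
#     )
#     # The value for a given training step is read by its global iteration
#     # index: lr_schedule[epoch * niter_per_ep + it]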
|
|
def get_params_groups(model):
    """
    Split the trainable parameters into four groups: regular weights,
    patch-embedding weights, and their no-weight-decay counterparts
    (biases and other 1-D parameters such as norm scales).
    """
    regularized = []
    not_regularized = []
    patch_embed = []
    patch_embed_not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Biases and 1-D parameters (e.g. norm weights) are not regularized.
        if name.endswith(".bias") or len(param.shape) == 1:
            if "patch_embed" in name:
                patch_embed_not_regularized.append(param)
            else:
                not_regularized.append(param)
        elif "patch_embed" in name:
            patch_embed.append(param)
        else:
            regularized.append(param)
    return [
        {"name": "normal_params", "params": regularized},
        {"name": "patch_embed", "params": patch_embed},
        {
            "name": "no_wd",
            "params": not_regularized,
            "apply_wd": False,
            "weight_decay": 0.0,
        },
        {
            "name": "patch_embed_no_wd",
            "params": patch_embed_not_regularized,
            "apply_wd": False,
            "weight_decay": 0.0,
        },
    ]
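
# Usage sketch for get_params_groups (hypothetical `model`; lr/weight_decay
# values are illustrative): the returned groups can be passed directly to a
# torch optimizer, and the per-group "weight_decay": 0.0 entries override the
# optimizer default:
#
#     param_groups = get_params_groups(model)
#     optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=0.04)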
|
|
class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
    """

    def __init__(
        self,
        params,
        lr=0,
        weight_decay=0,
        momentum=0.9,
        eta=0.001,
        weight_decay_filter=None,
        lars_adaptation_filter=None,
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            eta=eta,
            weight_decay_filter=weight_decay_filter,
            lars_adaptation_filter=lars_adaptation_filter,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad

                if dp is None:
                    continue

                # Weight decay and LARS adaptation are skipped for 1-D
                # parameters (biases, norm weights).
                if p.ndim != 1:
                    dp = dp.add(p, alpha=g["weight_decay"])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Layer-wise trust ratio: eta * ||w|| / ||grad + wd * w||.
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0, (g["eta"] * param_norm / update_norm), one
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                # Momentum buffer, then SGD-style update.
                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)

                p.add_(mu, alpha=-g["lr"])
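
# Usage sketch for LARS (hypothetical): it is typically combined with the
# parameter groups above, with the learning rate written into each group at
# every iteration from a schedule:
#
#     optimizer = LARS(get_params_groups(model))
#     for group in optimizer.param_groups:
#         group["lr"] = lr_schedule[global_step]
#     optimizer.step()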
|
|
def get_optimizer(student, len_dataloader, args):
    """
    Build the optimizer, the fp16 gradient scaler, and the per-iteration
    learning-rate, weight-decay and teacher-momentum schedules for training.
    """
    params_groups = get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)
    elif args.optimizer == "sgd":
        # lr is set to 0 here because it is overwritten at every iteration
        # from lr_schedule.
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)
    elif args.optimizer == "lars":
        optimizer = LARS(params_groups)
    else:
        raise ValueError(f"Unknown optimizer: {args.optimizer}")

    # Mixed-precision gradient scaler (only used when fp16 training is on).
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # Linear scaling rule: the base lr is scaled with the total batch size / 256.
    lr_schedule = cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.0,
        args.min_lr,
        args.epochs,
        len_dataloader,
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len_dataloader,
    )
    # Momentum for the teacher EMA update follows a cosine schedule up to 1.
    momentum_schedule = cosine_scheduler(
        args.momentum_teacher, 1, args.epochs, len_dataloader
    )
    print("Loss, optimizer and schedulers ready.")

    return optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule
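
# Usage sketch for get_optimizer (hypothetical caller; `args`, `student` and
# `data_loader` are assumed to exist outside this module):
#
#     optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule = \
#         get_optimizer(student, len(data_loader), args)
#     for it in range(len(lr_schedule)):              # global iteration index
#         for group in optimizer.param_groups:
#             group["lr"] = lr_schedule[it]
#             if group.get("apply_wd", True):         # skip the no-wd groups
#                 group["weight_decay"] = wd_schedule[it]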