Mugs / src /optimizer.py
zhoupans's picture
Upload 13 files
3c849be
raw
history blame
No virus
6.61 kB
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement some helper functions for optimizers.
"""
import numpy as np
import torch
import utils
def clip_gradients(model, clip):
    """
    Clip every parameter gradient *individually* to an L2 norm of at most ``clip``.

    This is per-parameter clipping, not global-norm clipping: each gradient
    whose norm exceeds ``clip`` is rescaled in place by
    ``clip / (norm + 1e-6)``; gradients already within the bound are left as is.

    Args:
        model: module whose ``named_parameters()`` are visited. Parameters
            whose ``grad`` is ``None`` are skipped.
        clip: maximum allowed L2 norm for each parameter's gradient.

    Returns:
        list of float: the pre-clipping gradient norms, one entry per
        parameter that had a gradient, in ``named_parameters()`` order.
    """
    grad_norms = []
    for _, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        norm = grad.data.norm(2)
        grad_norms.append(norm.item())
        # the 1e-6 keeps the ratio finite for an all-zero gradient
        scale = clip / (norm + 1e-6)
        if scale < 1:
            grad.data.mul_(scale)
    return grad_norms
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """
    Drop the gradients of all ``last_layer`` parameters during the first epochs.

    While ``epoch < freeze_last_layer``, every parameter whose name contains
    ``"last_layer"`` gets ``grad = None``. Optimizers that skip parameters
    without a gradient then leave them unchanged for this step. From epoch
    ``freeze_last_layer`` onwards this function does nothing.

    Args:
        epoch: current epoch index.
        model: module whose ``named_parameters()`` are inspected.
        freeze_last_layer: first epoch at which last-layer gradients are kept.
    """
    if epoch < freeze_last_layer:
        for param_name, param in model.named_parameters():
            if "last_layer" in param_name:
                param.grad = None
def cosine_scheduler(
    base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
):
    """
    Build a per-iteration schedule: linear warmup, then cosine decay.

    For the first ``warmup_epochs * niter_per_ep`` iterations the value rises
    linearly from ``start_warmup_value`` to ``base_value`` (both endpoints
    included). The remaining iterations follow a half cosine from
    ``base_value`` towards ``final_value``. The last entry stops one step short
    of ``final_value``, because progress runs over ``k / n`` for ``k < n``.

    Returns:
        np.ndarray with exactly ``epochs * niter_per_ep`` entries.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    # keep the original operation order so results match bit-for-bit
    cosine_part = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * steps / len(steps))
    )

    full_schedule = np.concatenate((warmup_part, cosine_part))
    assert len(full_schedule) == total_iters
    return full_schedule
def get_params_groups(model):
    """
    Split the trainable parameters of ``model`` into four optimizer param groups.

    A parameter counts as "no weight decay" when its name ends in ``.bias`` or
    it is 1-D (for example norm-layer weights). Independently, it belongs to the
    patch embedding when its name contains ``"patch_embed"``. Frozen parameters
    (``requires_grad == False``) are left out.

    Returns:
        list of dict, in this order:
          - ``normal_params``: decayed, outside the patch embedding
          - ``patch_embed``: decayed, inside the patch embedding
          - ``no_wd``: not decayed, outside the patch embedding
          - ``patch_embed_no_wd``: not decayed, inside the patch embedding
        The two no-decay groups also carry ``apply_wd=False`` and
        ``weight_decay=0.0``.
    """
    buckets = {
        "normal_params": [],
        "patch_embed": [],
        "no_wd": [],
        "patch_embed_no_wd": [],
    }
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # we do not regularize biases nor Norm parameters
        skip_decay = param_name.endswith(".bias") or len(param.shape) == 1
        in_patch_embed = "patch_embed" in param_name
        if skip_decay:
            target = "patch_embed_no_wd" if in_patch_embed else "no_wd"
        else:
            target = "patch_embed" if in_patch_embed else "normal_params"
        buckets[target].append(param)

    no_decay_opts = {"apply_wd": False, "weight_decay": 0.0}
    return [
        {"name": "normal_params", "params": buckets["normal_params"]},
        {"name": "patch_embed", "params": buckets["patch_embed"]},
        {"name": "no_wd", "params": buckets["no_wd"], **no_decay_opts},
        {
            "name": "patch_embed_no_wd",
            "params": buckets["patch_embed_no_wd"],
            **no_decay_opts,
        },
    ]
class LARS(torch.optim.Optimizer):
    """
    Layer-wise Adaptive Rate Scaling (LARS) optimizer with SGD-style momentum.

    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py

    Update rule for each parameter ``p`` that has a gradient:

    - If ``p.ndim != 1``:
        - Start from ``dp = grad + weight_decay * p``.
        - Rescale it by the trust ratio ``eta * ||p|| / ||dp||``.
        - The ratio falls back to 1 when either norm is zero.
    - If ``p.ndim == 1`` (e.g. biases and norm weights):
        - Use the raw gradient, with no weight decay and no trust-ratio scaling.
    - In both cases:
        - Accumulate ``dp`` into a momentum buffer: ``mu = momentum * mu + dp``.
        - Apply it with ``p -= lr * mu``.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate (defaults to 0; presumably driven by an external
            schedule — confirm with the training loop).
        weight_decay: weight-decay coefficient for non-1-D parameters.
        momentum: momentum factor for the ``mu`` buffer.
        eta: LARS trust coefficient.
        weight_decay_filter: stored in the param-group defaults only.
            ``step`` does not consult it; the ``ndim`` rule above is always used.
        lars_adaptation_filter: stored in the param-group defaults only,
            likewise not consulted by ``step``.
    """

    def __init__(
        self,
        params,
        lr=0,
        weight_decay=0,
        momentum=0.9,
        eta=0.001,
        weight_decay_filter=None,
        lars_adaptation_filter=None,
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            eta=eta,
            weight_decay_filter=weight_decay_filter,
            lars_adaptation_filter=lars_adaptation_filter,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # The step itself runs under no_grad, but the closure needs autograd
            # enabled so it can call backward().
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad
                if dp is None:
                    continue

                # 1-D parameters (biases / norm weights) skip both weight decay
                # and the layer-wise trust-ratio scaling.
                if p.ndim != 1:
                    dp = dp.add(p, alpha=g["weight_decay"])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # trust ratio eta * ||p|| / ||dp||, or 1 if either norm is 0
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0, (g["eta"] * param_norm / update_norm), one
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)
                p.add_(mu, alpha=-g["lr"])

        return loss
def get_optimizer(student, len_dataloader, args):
    """
    Build the optimizer, the optional fp16 grad scaler and the per-iteration schedules.

    Args:
        student: model whose trainable parameters are optimized; they are
            split into param groups by ``get_params_groups``.
        len_dataloader: number of iterations per epoch; used to size every schedule.
        args: config namespace. Reads ``optimizer`` ("adamw", "sgd" or
            "lars"), ``use_fp16``, ``lr``, ``batch_size_per_gpu``, ``min_lr``,
            ``epochs``, ``warmup_epochs``, ``weight_decay``,
            ``weight_decay_end`` and ``momentum_teacher``.

    Returns:
        tuple ``(optimizer, fp16_scaler, lr_schedule, wd_schedule,
        momentum_schedule)``. ``fp16_scaler`` is ``None`` unless
        ``args.use_fp16`` is set. Each schedule comes from ``cosine_scheduler``
        with ``args.epochs`` epochs of ``len_dataloader`` iterations.

    Raises:
        ValueError: if ``args.optimizer`` is not a supported name.
    """
    # ============ preparing optimizer ... ============
    # Factories are lambdas so nothing is constructed until the name is validated.
    optimizer_factories = {
        "adamw": lambda groups: torch.optim.AdamW(groups),  # to use with ViTs
        # lr is set by scheduler
        "sgd": lambda groups: torch.optim.SGD(groups, lr=0, momentum=0.9),
        # to use with convnet and large batches
        "lars": lambda groups: LARS(groups),
    }
    # Reject unknown names up front so the error names the bad value instead of
    # surfacing later as an unbound local variable.
    if args.optimizer not in optimizer_factories:
        raise ValueError(
            f"Unknown optimizer {args.optimizer!r}; "
            f"expected one of {sorted(optimizer_factories)}"
        )
    params_groups = get_params_groups(student)
    optimizer = optimizer_factories[args.optimizer](params_groups)

    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = cosine_scheduler(
        args.lr
        * (args.batch_size_per_gpu * utils.get_world_size())
        / 256.0,  # linear scaling rule
        args.min_lr,
        args.epochs,
        len_dataloader,
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len_dataloader,
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = cosine_scheduler(
        args.momentum_teacher, 1, args.epochs, len_dataloader
    )
    print("Loss, optimizer and schedulers ready.")
    return optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule