import itertools
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union

import torch

from detectron2.config import CfgNode
from detectron2.solver.build import maybe_add_gradient_clipping


def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Build an SGD or AdamW optimizer from cfg, with optional layer-wise lr
    decay for ViT backbone parameters and optional full-model gradient clipping."""
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    optimizer_type = cfg.SOLVER.OPTIMIZER

    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Skip parameters that are shared between modules and already collected.
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY

        # Scale the learning rate by block depth for ViT backbone parameters.
        if cfg.SOLVER.VIT_LAYER_DECAY:
            lr = lr * get_vit_lr_decay_rate(key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS)

        param = {"params": [value], "lr": lr}
        # AdamW gets a single global weight decay in its constructor below;
        # other optimizers receive it per parameter group.
        if optimizer_type != 'ADAMW':
            param['weight_decay'] = weight_decay
        params += [param]

    def maybe_add_full_model_gradient_clipping(optim):
        # Wrap the optimizer class so that, when enabled, the gradient norm is
        # clipped over all parameters at once rather than per parameter group.
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim

    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    # For any clipping type other than "full_model", fall back to detectron2's
    # built-in per-parameter gradient clipping.
    if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
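

# A minimal usage sketch, not part of the original source: the builder above
# reads several custom config keys (SOLVER.OPTIMIZER, SOLVER.VIT_LAYER_DECAY,
# SOLVER.VIT_LAYER_DECAY_RATE, MODEL.VIT_LAYERS) that do not exist in
# detectron2's default config. A hypothetical helper like the one below would
# register them, with illustrative defaults, before build_custom_optimizer is
# called (the same pattern projects use in their add_*_config functions).
def add_custom_solver_config(cfg: CfgNode) -> None:
    # Hypothetical helper; key names come from the builder above, defaults are assumed.
    cfg.SOLVER.OPTIMIZER = "SGD"             # 'SGD' or 'ADAMW'
    cfg.SOLVER.VIT_LAYER_DECAY = False       # enable layer-wise lr decay for the ViT backbone
    cfg.SOLVER.VIT_LAYER_DECAY_RATE = 0.7    # per-layer decay factor (assumed default)
    cfg.MODEL.VIT_LAYERS = 12                # number of ViT blocks in the backbone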


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
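

# Illustrative check of the decay schedule, not part of the original source.
# The parameter names below are hypothetical ViTDet-style names: with
# lr_decay_rate=0.7 and num_layers=12, the patch/position embeddings get the
# smallest multiplier (0.7 ** 13), block i gets 0.7 ** (12 - i), and
# parameters outside the backbone keep the base lr (multiplier 1.0).
if __name__ == "__main__":
    for _name in [
        "backbone.net.pos_embed",
        "backbone.net.blocks.0.attn.qkv.weight",
        "backbone.net.blocks.11.mlp.fc1.weight",
        "roi_heads.box_predictor.cls_score.weight",
    ]:
        print(_name, get_vit_lr_decay_rate(_name, lr_decay_rate=0.7, num_layers=12))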